[
  {
    "path": ".github/workflows/publish.yaml",
    "content": "name: release\n\non:\n  push:\n    tags:\n      - 'v**'\n\nconcurrency:\n  group: ${{ github.workflow }}-${{ github.ref }}-publish\n  cancel-in-progress: true\n\njobs:\n  build-n-publish:\n    runs-on: ubuntu-20.04\n    #if: startsWith(github.event.ref, 'refs/tags')\n    steps:\n      - uses: actions/checkout@v2\n      - name: Set up Python 3.10\n        uses: actions/setup-python@v2\n        with:\n          python-version: '3.10'\n      - name: Install wheel\n        run: pip install wheel==0.44.0 && pip install -r requirements.txt\n      - name: Build DiffSynth\n        run: python -m build\n      - name: Publish package to PyPI\n        run: |\n          pip install twine\n          twine upload dist/* --skip-existing -u __token__ -p ${{ secrets.PYPI_API_TOKEN }}\n"
  },
  {
    "path": ".gitignore",
    "content": "/data\n/models\n/scripts\n/diffusers\n/.vscode\n*.pkl\n*.safetensors\n*.pth\n*.ckpt\n*.pt\n*.bin\n*.DS_Store\n*.msc\n*.mv\nlog*.txt\n\n# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\ncover/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\n.pybuilder/\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n#   For a library or package, you might want to ignore these files since the code is\n#   intended to run in multiple environments; otherwise, check them in:\n# .python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# poetry\n#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.\n#   This is especially recommended for binary packages to ensure reproducibility, and is more\n#   
commonly ignored for libraries.\n#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control\n#poetry.lock\n\n# pdm\n#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.\n#pdm.lock\n#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it\n#   in version control.\n#   https://pdm.fming.dev/#use-with-ide\n.pdm.toml\n\n# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n\n# pytype static type analyzer\n.pytype/\n\n# Cython debug symbols\ncython_debug/\n\n# PyCharm\n#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can\n#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore\n#  and can be added to the global gitignore or merged into this file.  For a more nuclear\n#  option (not recommended) you can uncomment the following to ignore the entire idea folder.\n#.idea/"
  },
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [2023] [Zhongjie Duan]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "README.md",
    "content": "# DiffSynth-Studio\n\n<a href=\"https://github.com/modelscope/DiffSynth-Studio\"><img src=\".github/workflows/logo.gif\" title=\"Logo\" style=\"max-width:100%;\" width=\"55\" /></a> <a href=\"https://trendshift.io/repositories/10946\" target=\"_blank\"><img src=\"https://trendshift.io/api/badge/repositories/10946\" alt=\"modelscope%2FDiffSynth-Studio | Trendshift\" style=\"width: 250px; height: 55px;\" width=\"250\" height=\"55\"/></a>\n\n[![PyPI](https://img.shields.io/pypi/v/DiffSynth)](https://pypi.org/project/DiffSynth/)\n[![license](https://img.shields.io/github/license/modelscope/DiffSynth-Studio.svg)](https://github.com/modelscope/DiffSynth-Studio/blob/master/LICENSE)\n[![open issues](https://isitmaintained.com/badge/open/modelscope/DiffSynth-Studio.svg)](https://github.com/modelscope/DiffSynth-Studio/issues)\n[![GitHub pull-requests](https://img.shields.io/github/issues-pr/modelscope/DiffSynth-Studio.svg)](https://GitHub.com/modelscope/DiffSynth-Studio/pull/)\n[![GitHub latest commit](https://badgen.net/github/last-commit/modelscope/DiffSynth-Studio)](https://GitHub.com/modelscope/DiffSynth-Studio/commit/)\n\n[切换到中文版](./README_zh.md)\n\n## Introduction\n\n> DiffSynth-Studio Documentation: [中文版](https://diffsynth-studio-doc.readthedocs.io/zh-cn/latest/), [English version](https://diffsynth-studio-doc.readthedocs.io/en/latest/)\n\nWelcome to the magical world of Diffusion models! DiffSynth-Studio is an open-source Diffusion model engine developed and maintained by the [ModelScope Community](https://www.modelscope.cn/). 
We hope to foster technological innovation through framework construction, aggregate the power of the open-source community, and explore the boundaries of generative model technology!\n\nDiffSynth currently includes two open-source projects:\n* [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio): Focused on aggressive technical exploration, targeting academia, and providing cutting-edge model capability support.\n* [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine): Focused on stable model deployment, targeting industry, and providing higher computational performance and more stable features.\n\n[DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio) and [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine) are the core engines of the ModelScope AIGC Zone. You are welcome to try out our carefully crafted product features:\n\n* ModelScope AIGC Zone (for Chinese users): https://modelscope.cn/aigc/home\n* ModelScope Civision (for global users): https://modelscope.ai/civision/home\n\nWe believe that a well-developed open-source code framework can lower the barrier to technical exploration. We have built many [interesting technologies](#innovative-achievements) on top of this codebase. Perhaps you also have many wild ideas; with DiffSynth-Studio, you can quickly realize them. For this reason, we have prepared detailed documentation for developers. We hope that through these documents, developers can understand the principles of Diffusion models, and we look forward to expanding the boundaries of technology together with you.\n\n## Update History\n\n> DiffSynth-Studio has undergone major version updates, and some old features are no longer maintained. 
If you need to use old features, please switch to the [last historical version](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3) before the major version update.\n\n> Currently, this project has limited development staff, with most of the work handled by [Artiprocher](https://github.com/Artiprocher) and [mi804](https://github.com/mi804). Therefore, new features will be developed relatively slowly, and our capacity to respond to and resolve issues is limited. We apologize for this and appreciate your understanding.\n\n- **January 19, 2026**: Added support for [openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p) and [openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p) models, including training and inference capabilities. [Documentation](/docs/en/Model_Details/Wan.md) and [example code](/examples/mova/) are now available.\n\n- **March 12, 2026**: We have added support for the [LTX-2.3](https://modelscope.cn/models/Lightricks/LTX-2.3) audio-video generation model. Its features include text-to-audio/video, image-to-audio/video, IC-LoRA control, audio-to-video, and audio-video inpainting. Both inference and training are fully supported. For details, please refer to the [documentation](/docs/en/Model_Details/LTX-2.md) and [code](/examples/ltx2/).\n\n- **March 3, 2026**: We released the [DiffSynth-Studio/Qwen-Image-Layered-Control-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control-V2) model, which is an updated version of Qwen-Image-Layered-Control. In addition to the originally supported text-guided functionality, it adds brush-controlled layer separation capabilities.\n\n- **March 2, 2026** Added support for [Anima](https://modelscope.cn/models/circlestone-labs/Anima). For details, please refer to the [documentation](/docs/en/Model_Details/Anima.md). This is an interesting anime-style image generation model. 
We look forward to its future updates.\n\n<details>\n<summary>More</summary>\n\n- **February 26, 2026** Added full and LoRA training support for the LTX-2 audio-video generation model. See the [documentation](/docs/en/Model_Details/LTX-2.md) for details.\n\n- **February 10, 2026** Added inference support for the LTX-2 audio-video generation model. See the [documentation](/docs/en/Model_Details/LTX-2.md) for details. Support for model training will be implemented in the future.\n\n- **February 2, 2026** The first document of the Research Tutorial series is now available, guiding you through training a small 0.1B text-to-image model from scratch. For details, see the [documentation](/docs/en/Research_Tutorial/train_from_scratch.md) and [model](https://modelscope.cn/models/DiffSynth-Studio/AAAMyModel). We hope DiffSynth-Studio can evolve into a more powerful training framework for Diffusion models.\n\n- **January 27, 2026**: [Z-Image](https://modelscope.cn/models/Tongyi-MAI/Z-Image) is released, and our [Z-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-i2L) model is released concurrently. You can try it in [ModelScope Studios](https://modelscope.cn/studios/DiffSynth-Studio/Z-Image-i2L). For details, see the [documentation](/docs/en/Model_Details/Z-Image.md).\n\n- **January 19, 2026**: Added support for [FLUX.2-klein-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B) and [FLUX.2-klein-9B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B) models, including training and inference capabilities. [Documentation](/docs/en/Model_Details/FLUX2.md) and [example code](/examples/flux2/) are now available.\n\n- **January 12, 2026**: We trained and open-sourced a text-guided image layer separation model ([Model Link](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)). Given an input image and a textual description, the model isolates the image layer corresponding to the described content. 
For more details, please refer to our blog post ([Chinese version](https://modelscope.cn/learn/4938), [English version](https://huggingface.co/blog/kelseye/qwen-image-layered-control)).\n\n- **December 24, 2025**: Based on Qwen-Image-Edit-2511, we trained an In-Context Editing LoRA model ([Model Link](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-2511-ICEdit-LoRA)). This model takes three images as input (Image A, Image B, and Image C), and automatically analyzes the transformation from Image A to Image B, then applies the same transformation to Image C to generate Image D. For more details, please refer to our blog post ([Chinese version](https://mp.weixin.qq.com/s/41aEiN3lXKGCJs1-we4Q2g), [English version](https://huggingface.co/blog/kelseye/qwen-image-edit-2511-icedit-lora)).\n\n- **December 9, 2025** We release a wild model based on DiffSynth-Studio 2.0: [Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L) (Image-to-LoRA). This model takes an image as input and outputs a LoRA. Although this version still has significant room for improvement in terms of generalization, detail preservation, and other aspects, we are open-sourcing these models to inspire more innovative research. For more details, please refer to our [blog](https://huggingface.co/blog/kelseye/qwen-image-i2l).\n\n- **December 4, 2025** DiffSynth-Studio 2.0 released! 
Many new features are now available:\n  - [Documentation](/docs/en/README.md) is online, and we are continuously improving and updating it\n  - [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md) module upgraded, supporting layer-level disk offload, which frees both RAM and VRAM at the same time\n  - New model support\n    - Z-Image Turbo: [Model](https://www.modelscope.ai/models/Tongyi-MAI/Z-Image-Turbo), [Documentation](/docs/en/Model_Details/Z-Image.md), [Code](/examples/z_image/)\n    - FLUX.2-dev: [Model](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev), [Documentation](/docs/en/Model_Details/FLUX2.md), [Code](/examples/flux2/)\n  - Training framework upgrade\n    - [Split Training](/docs/zh/Training/Split_Training.md): Supports automatically splitting the training process into two stages: data processing and training (even for training ControlNet or any other model). Computations that do not require gradient backpropagation, such as text encoding and VAE encoding, are performed during the data processing stage, while other computations are handled during the training stage. 
This yields faster training with lower VRAM requirements.\n    - [Differential LoRA Training](/docs/zh/Training/Differential_LoRA.md): This is a training technique we used in [ArtAug](https://www.modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1), now available for LoRA training of any model.\n    - [FP8 Training](/docs/zh/Training/FP8_Precision.md): During training, FP8 can be applied to any model whose weights are not being trained, i.e., models with gradients turned off or whose gradients only affect LoRA weights.\n\n- **November 4, 2025** Supported the [ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B) model, which is trained based on Wan 2.1 and supports generating corresponding actions based on reference videos.\n\n- **October 30, 2025** Supported the [meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video) model, which supports text-to-video, image-to-video, and video continuation. This model uses the Wan framework for inference and training in this project.\n\n- **October 27, 2025** Supported the [krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video) model, adding another member to the Wan model ecosystem.\n\n- **September 23, 2025** [DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster) released! This model was jointly developed and open-sourced by us and the Taobao Experience Design Team. Built upon Qwen-Image, the model is specifically designed for e-commerce poster scenarios, supporting precise partition layout control. Please refer to [our sample code](./examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py).\n\n- **September 9, 2025** Our training framework supports various training modes. It is currently adapted for Qwen-Image: in addition to the standard SFT training mode, Direct Distill is now supported. 
Please refer to [our sample code](./examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh). This feature is experimental, and we will continue to improve it to support more comprehensive model training functions.\n\n- **August 28, 2025** We support Wan2.2-S2V, an audio-driven cinematic video generation model. See [./examples/wanvideo/](./examples/wanvideo/).\n\n- **August 21, 2025** [DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2) released! Compared to the V1 version, the training dataset has been changed to [Qwen-Image-Self-Generated-Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Qwen-Image-Self-Generated-Dataset), so the generated images better conform to Qwen-Image's own image distribution and style. Please refer to [our sample code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py).\n\n- **August 21, 2025** We open-sourced the [DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union) structural control LoRA model, adopting the In Context technical route, supporting multiple categories of structural control conditions, including canny, depth, lineart, softedge, normal, and openpose. Please refer to [our sample code](./examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py).\n\n- **August 20, 2025** We open-sourced the [DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix) model, improving the editing effect of Qwen-Image-Edit on low-resolution image inputs. 
Please refer to [our sample code](./examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py).\n\n- **August 19, 2025** 🔥 Qwen-Image-Edit has been open-sourced, welcoming a new member to the image editing model family!\n\n- **August 18, 2025** We trained and open-sourced the Qwen-Image inpainting ControlNet model [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint). The model structure adopts a lightweight design. Please refer to [our sample code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py).\n\n- **August 15, 2025** We open-sourced the [Qwen-Image-Self-Generated-Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Qwen-Image-Self-Generated-Dataset) dataset. This is an image dataset generated using the Qwen-Image model, containing 160,000 `1024 x 1024` images. It includes general, English text rendering, and Chinese text rendering subsets. We provide annotations for image descriptions, entities, and structural control images for each image. Developers can use this dataset to train ControlNet and EliGen models for Qwen-Image. We aim to promote technological development through open-sourcing!\n\n- **August 13, 2025** We trained and open-sourced the Qwen-Image ControlNet model [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth). The model structure adopts a lightweight design. Please refer to [our sample code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py).\n\n- **August 12, 2025** We trained and open-sourced the Qwen-Image ControlNet model [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny). The model structure adopts a lightweight design. 
Please refer to [our sample code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py).\n\n- **August 11, 2025** We open-sourced the distilled acceleration model [DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA) for Qwen-Image, following the same training process as [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full), but the model structure has been modified to LoRA, thus being better compatible with other open-source ecosystem models.\n\n- **August 7, 2025** We open-sourced the entity control LoRA model [DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen) for Qwen-Image. Qwen-Image-EliGen can achieve entity-level controlled text-to-image generation. Technical details can be found in [the paper](https://arxiv.org/abs/2501.01097). Training dataset: [EliGenTrainSet](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet).\n\n- **August 5, 2025** We open-sourced the distilled acceleration model [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full) for Qwen-Image, achieving approximately 5x acceleration.\n\n- **August 4, 2025** 🔥 Qwen-Image open-sourced, welcome a new member to the image generation model family!\n\n- **August 1, 2025** [FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev) open-sourced, a text-to-image model focused on aesthetic photography. We provided comprehensive support in a timely manner, including low VRAM layer-by-layer offload, LoRA training, and full training. For more details, please refer to [./examples/flux/](./examples/flux/).\n\n- **July 28, 2025** Wan 2.2 open-sourced. 
We provided comprehensive support in a timely manner, including low VRAM layer-by-layer offload, FP8 quantization, sequence parallelism, LoRA training, and full training. For more details, please refer to [./examples/wanvideo/](./examples/wanvideo/).\n\n- **July 11, 2025** We propose Nexus-Gen, a unified framework that combines the language reasoning capabilities of Large Language Models (LLMs) with the image generation capabilities of diffusion models. This framework supports seamless image understanding, generation, and editing tasks.\n  - Paper: [Nexus-Gen: Unified Image Understanding, Generation, and Editing via Prefilled Autoregression in Shared Embedding Space](https://arxiv.org/pdf/2504.21356)\n  - GitHub Repository: https://github.com/modelscope/Nexus-Gen\n  - Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2), [HuggingFace](https://huggingface.co/modelscope/Nexus-GenV2)\n  - Training Dataset: [ModelScope Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Nexus-Gen-Training-Dataset)\n  - Online Experience: [ModelScope Nexus-Gen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/Nexus-Gen)\n\n- **June 15, 2025** ModelScope's official evaluation framework [EvalScope](https://github.com/modelscope/evalscope) now supports text-to-image generation evaluation. Please refer to the [best practices](https://evalscope.readthedocs.io/zh-cn/latest/best_practice/t2i_eval.html) guide to try it out.\n\n- **March 25, 2025** Our new project [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine) has been open-sourced! It focuses on stable model deployment for industrial use, offering better engineering support, higher computational performance, and more stable features.\n\n- **March 31, 2025** We support InfiniteYou, a face feature preservation method for FLUX. 
More details can be found in [./examples/InfiniteYou/](./examples/InfiniteYou/).\n\n- **March 13, 2025** We support HunyuanVideo-I2V, the image-to-video generation version of Tencent's open-source HunyuanVideo. More details can be found in [./examples/HunyuanVideo/](./examples/HunyuanVideo/).\n\n- **February 25, 2025** We support Wan-Video, a series of state-of-the-art video synthesis models open-sourced by Alibaba. See [./examples/wanvideo/](./examples/wanvideo/).\n\n- **February 17, 2025** We support [StepVideo](https://modelscope.cn/models/stepfun-ai/stepvideo-t2v/summary)! Advanced video synthesis model! See [./examples/stepvideo](./examples/stepvideo/).\n\n- **December 31, 2024** We propose EliGen, a new framework for entity-level controlled text-to-image generation, supplemented with an inpainting fusion pipeline, extending its capabilities to image inpainting tasks. EliGen can seamlessly integrate existing community models such as IP-Adapter and In-Context LoRA, enhancing their versatility. For more details, see [./examples/EntityControl](./examples/EntityControl/).\n  - Paper: [EliGen: Entity-Level Controlled Image Generation with Regional Attention](https://arxiv.org/abs/2501.01097)\n  - Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen), [HuggingFace](https://huggingface.co/modelscope/EliGen)\n  - Online Experience: [ModelScope EliGen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/EliGen)\n  - Training Dataset: [EliGen Train Set](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet)\n\n- **December 19, 2024** We implemented advanced VRAM management for HunyuanVideo, enabling video generation with resolutions of 129x720x1280 on 24GB VRAM or 129x512x384 on just 6GB VRAM. More details can be found in [./examples/HunyuanVideo/](./examples/HunyuanVideo/).\n\n- **December 18, 2024** We propose ArtAug, a method to improve text-to-image models through synthesis-understanding interaction. 
We trained an ArtAug enhancement module for FLUX.1-dev in LoRA format. This model incorporates the aesthetic understanding of Qwen2-VL-72B into FLUX.1-dev, thereby improving the quality of generated images.\n  - Paper: https://arxiv.org/abs/2412.12888\n  - Example: https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/ArtAug\n  - Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ArtAug-lora-FLUX.1dev-v1)\n  - Demo: [ModelScope](https://modelscope.cn/aigc/imageGeneration?tab=advanced&versionId=7228&modelType=LoRA&sdVersion=FLUX_1&modelUrl=modelscope%3A%2F%2FDiffSynth-Studio%2FArtAug-lora-FLUX.1dev-v1%3Frevision%3Dv1.0), HuggingFace (coming soon)\n\n- **October 25, 2024** We provide extensive FLUX ControlNet support. This project supports many different ControlNet models and can be freely combined, even if their structures are different. Additionally, ControlNet models are compatible with high-resolution optimization and partition control technologies, enabling very powerful controllable image generation. See [`./examples/ControlNet/`](./examples/ControlNet/).\n\n- **October 8, 2024** We released extended LoRAs based on CogVideoX-5B and ExVideo. You can download this model from [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1) or [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1).\n\n- **August 22, 2024** This project now supports CogVideoX-5B. See [here](/examples/video_synthesis/). We provide several interesting features for this text-to-video model, including:\n  - Text-to-video\n  - Video editing\n  - Self super-resolution\n  - Video interpolation\n\n- **August 22, 2024** We implemented an interesting brush feature that supports all text-to-image models. 
Now you can create stunning images with the assistance of AI using the brush!\n  - Use it in our [WebUI](#usage-in-webui).\n\n- **August 21, 2024** DiffSynth-Studio now supports FLUX.\n  - Enable CFG and high-resolution inpainting to improve visual quality. See [here](/examples/image_synthesis/README.md)\n  - LoRA, ControlNet, and other addon models will be released soon.\n\n- **June 21, 2024** We propose ExVideo, a post-training fine-tuning technique aimed at enhancing the capabilities of video generation models. We extended Stable Video Diffusion to achieve long video generation of up to 128 frames.\n  - [Project Page](https://ecnu-cilab.github.io/ExVideoProjectPage/)\n  - Source code has been released in this repository. See [`examples/ExVideo`](./examples/ExVideo/).\n  - Model has been released at [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1) and [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1).\n  - Technical report has been released at [arXiv](https://arxiv.org/abs/2406.14130).\n  - You can try ExVideo in this [demo](https://huggingface.co/spaces/modelscope/ExVideo-SVD-128f-v1)!\n\n- **June 13, 2024** DiffSynth Studio has migrated to ModelScope. The development team has also transitioned from \"me\" to \"us\". Of course, I will still participate in subsequent development and maintenance work.\n\n- **January 29, 2024** We propose Diffutoon, an excellent cartoon coloring solution.\n  - [Project Page](https://ecnu-cilab.github.io/DiffutoonProjectPage/)\n  - Source code has been released in this project.\n  - Technical report (IJCAI 2024) has been released at [arXiv](https://arxiv.org/abs/2401.16224).\n\n- **December 8, 2023** We decided to initiate a new project aimed at unleashing the potential of diffusion models, especially in video synthesis. 
The development work of this project officially began.\n\n- **November 15, 2023** We propose FastBlend, a powerful video deflickering algorithm.\n  - sd-webui extension has been released at [GitHub](https://github.com/Artiprocher/sd-webui-fastblend).\n  - Demonstration videos have been showcased on Bilibili, including three tasks:\n    - [Video Deflickering](https://www.bilibili.com/video/BV1d94y1W7PE)\n    - [Video Interpolation](https://www.bilibili.com/video/BV1Lw411m71p)\n    - [Image-Driven Video Rendering](https://www.bilibili.com/video/BV1RB4y1Z7LF)\n  - Technical report has been released at [arXiv](https://arxiv.org/abs/2311.09265).\n  - Unofficial ComfyUI extensions developed by other users have been released at [GitHub](https://github.com/AInseven/ComfyUI-fastblend).\n\n- **October 1, 2023** We released an early version of the project named FastSDXL. This was an initial attempt to build a diffusion engine.\n  - Source code has been released at [GitHub](https://github.com/Artiprocher/FastSDXL).\n  - FastSDXL includes a trainable OLSS scheduler to improve efficiency.\n    - The original repository of OLSS is located [here](https://github.com/alibaba/EasyNLP/tree/master/diffusion/olss_scheduler).\n    - Technical report (CIKM 2023) has been released at [arXiv](https://arxiv.org/abs/2305.14677).\n    - Demonstration video has been released at [Bilibili](https://www.bilibili.com/video/BV1w8411y7uj).\n    - Since OLSS requires additional training, we did not implement it in this project.\n\n- **August 29, 2023** We propose DiffSynth, a video synthesis framework.\n  - [Project Page](https://ecnu-cilab.github.io/DiffSynth.github.io/).\n  - Source code has been released at [EasyNLP](https://github.com/alibaba/EasyNLP/tree/master/diffusion/DiffSynth).\n  - Technical report (ECML PKDD 2024) has been released at [arXiv](https://arxiv.org/abs/2308.03463).\n\n</details>\n\n## Installation\n\nInstall from source (recommended):\n\n```\ngit clone 
https://github.com/modelscope/DiffSynth-Studio.git\ncd DiffSynth-Studio\npip install -e .\n```\n\nFor more installation methods and instructions for non-NVIDIA GPUs, please refer to the [Installation Guide](/docs/en/Pipeline_Usage/Setup.md).\n\n## Basic Framework\n\nDiffSynth-Studio redesigns the inference and training pipelines for mainstream diffusion models (including FLUX, Wan, etc.), enabling efficient memory management and flexible model training.\n\n<details>\n<summary>Environment Variable Configuration</summary>\n\n> Before running model inference or training, you can configure settings such as the model download source via [environment variables](/docs/en/Pipeline_Usage/Environment_Variables.md).\n>\n> By default, this project downloads models from ModelScope. For users outside China, you can configure the system to download models from the ModelScope international site as follows:\n>\n> ```python\n> import os\n> os.environ[\"MODELSCOPE_DOMAIN\"] = \"www.modelscope.ai\"\n> ```\n>\n> To download models from other sources, please modify the environment variable [DIFFSYNTH_DOWNLOAD_SOURCE](/docs/en/Pipeline_Usage/Environment_Variables.md#diffsynth_download_source).\n\n</details>\n\n### Image Synthesis\n\n![Image](https://github.com/user-attachments/assets/c01258e2-f251-441a-aa1e-ebb22f02594d)\n\n#### Z-Image: [/docs/en/Model_Details/Z-Image.md](/docs/en/Model_Details/Z-Image.md)\n\n<details>\n\n<summary>Quick Start</summary>\n\nRunning the following code will quickly load the [Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo) model for inference. FP8 quantization significantly degrades image quality, so we do not recommend enabling any quantization for the Z-Image Turbo model. 
CPU offloading is recommended, and the model can run with as little as 8 GB of GPU memory.\n\n```python\nfrom diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig\nimport torch\n\nvram_config = {\n    \"offload_dtype\": torch.bfloat16,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = ZImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"transformer/*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"text_encoder/*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"tokenizer/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\nprompt = \"Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. 
Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights.\"\nimage = pipe(prompt=prompt, seed=42, rand_device=\"cuda\")\nimage.save(\"image.jpg\")\n```\n\n</details>\n\n<details>\n\n<summary>Examples</summary>\n\nExample code for Z-Image is available at: [/examples/z_image/](/examples/z_image/)\n\n|Model ID|Inference|Low VRAM Inference|Full Training|Validation After Full Training|LoRA Training|Validation After LoRA Training|\n|-|-|-|-|-|-|-|\n|[Tongyi-MAI/Z-Image](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image)|[code](/examples/z_image/model_inference/Z-Image.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image.py)|[code](/examples/z_image/model_training/full/Z-Image.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image.py)|[code](/examples/z_image/model_training/lora/Z-Image.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image.py)|\n|[DiffSynth-Studio/Z-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-i2L)|[code](/examples/z_image/model_inference/Z-Image-i2L.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-i2L.py)|-|-|-|-|\n|[Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo)|[code](/examples/z_image/model_inference/Z-Image-Turbo.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo.py)|\n|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_training/full/Z-Image-Tur
bo-Fun-Controlnet-Union-2.1.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|\n|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|\n|[PAI/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|\n\n</details>\n\n#### FLUX.2: [/docs/en/Model_Details/FLUX2.md](/docs/en/Model_Details/FLUX2.md)\n\n<details>\n\n<summary>Quick Start</summary>\n\nRunning the following code will quickly load the 
[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev) model for inference. VRAM management is enabled, and the framework automatically loads model parameters based on available GPU memory. The model can run with as little as 10 GB of VRAM.\n\n```python\nfrom diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig\nimport torch\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.float8_e4m3fn,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e4m3fn,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = Flux2ImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-dev\", origin_file_pattern=\"text_encoder/*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-dev\", origin_file_pattern=\"transformer/*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-dev\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"black-forest-labs/FLUX.2-dev\", origin_file_pattern=\"tokenizer/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\nprompt = \"High resolution. A dreamy underwater portrait of a serene young woman in a flowing blue dress. Her hair floats softly around her face, strands delicately suspended in the water. Clear, shimmering light filters through, casting gentle highlights, while tiny bubbles rise around her. 
Her expression is calm, her features finely detailed—creating a tranquil, ethereal scene.\"\nimage = pipe(prompt, seed=42, rand_device=\"cuda\", num_inference_steps=50)\nimage.save(\"image.jpg\")\n```\n\n</details>\n\n<details>\n\n<summary>Examples</summary>\n\nExample code for FLUX.2 is available at: [/examples/flux2/](/examples/flux2/)\n\n| Model ID | Inference | Low-VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |\n|-|-|-|-|-|-|-|\n|[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](/examples/flux2/model_inference/FLUX.2-dev.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|-|-|[code](/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)|\n|[black-forest-labs/FLUX.2-klein-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-4B.py)|\n|[black-forest-labs/FLUX.2-klein-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-9B.py)|\n|[black-forest-labs/FLUX.2-klein-base-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-4B)|[code](/examples/flux2/model_inference/FL
UX.2-klein-base-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-4B.py)|\n|[black-forest-labs/FLUX.2-klein-base-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-9B.py)|\n\n</details>\n\n#### Anima: [/docs/en/Model_Details/Anima.md](/docs/en/Model_Details/Anima.md)\n\n<details>\n\n<summary>Quick Start</summary>\n\nRun the following code to quickly load the [circlestone-labs/Anima](https://www.modelscope.cn/models/circlestone-labs/Anima) model and perform inference. VRAM management is enabled, and the framework will automatically control the loading of model parameters based on available VRAM. 
The model can run with a minimum of 8GB VRAM.\n\n```python\nfrom diffsynth.pipelines.anima_image import AnimaImagePipeline, ModelConfig\nimport torch\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": \"disk\",\n    \"onload_device\": \"disk\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = AnimaImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"circlestone-labs/Anima\", origin_file_pattern=\"split_files/diffusion_models/anima-preview.safetensors\", **vram_config),\n        ModelConfig(model_id=\"circlestone-labs/Anima\", origin_file_pattern=\"split_files/text_encoders/qwen_3_06b_base.safetensors\", **vram_config),\n        ModelConfig(model_id=\"circlestone-labs/Anima\", origin_file_pattern=\"split_files/vae/qwen_image_vae.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen3-0.6B\", origin_file_pattern=\"./\"),\n    tokenizer_t5xxl_config=ModelConfig(model_id=\"stabilityai/stable-diffusion-3.5-large\", origin_file_pattern=\"tokenizer_3/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\nprompt = \"Masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait.\"\nnegative_prompt = \"worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,\"\nimage = pipe(prompt, seed=0, num_inference_steps=50)\nimage.save(\"image.jpg\")\n```\n\n</details>\n\n<details>\n\n<summary>Examples</summary>\n\nExample code for Anima is located at: [/examples/anima/](/examples/anima/)\n\n| Model ID | Inference | Low VRAM Inference | Full Training | Validation after Full Training | LoRA Training | 
Validation after LoRA Training |\n|-|-|-|-|-|-|-|\n|[circlestone-labs/Anima](https://www.modelscope.cn/models/circlestone-labs/Anima)|[code](/examples/anima/model_inference/anima-preview.py)|[code](/examples/anima/model_inference_low_vram/anima-preview.py)|[code](/examples/anima/model_training/full/anima-preview.sh)|[code](/examples/anima/model_training/validate_full/anima-preview.py)|[code](/examples/anima/model_training/lora/anima-preview.sh)|[code](/examples/anima/model_training/validate_lora/anima-preview.py)|\n\n</details>\n\n#### Qwen-Image: [/docs/en/Model_Details/Qwen-Image.md](/docs/en/Model_Details/Qwen-Image.md)\n\n<details>\n\n<summary>Quick Start</summary>\n\nRunning the following code will quickly load the [Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image) model for inference. VRAM management is enabled, and the framework automatically adjusts model parameter loading based on available GPU memory. The model can run with as little as 8 GB of VRAM.\n\n```python\nfrom diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nimport torch\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.float8_e4m3fn,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e4m3fn,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\", **vram_config),\n    ],\n    
tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\nprompt = \"精致肖像，水下少女，蓝裙飘逸，发丝轻扬，光影透澈，气泡环绕，面容恬静，细节精致，梦幻唯美。\"\nimage = pipe(prompt, seed=0, num_inference_steps=40)\nimage.save(\"image.jpg\")\n```\n\n</details>\n\n<details>\n\n<summary>Model Lineage</summary>\n\n```mermaid\ngraph LR;\n    Qwen/Qwen-Image-->Qwen/Qwen-Image-Edit;\n    Qwen/Qwen-Image-Edit-->Qwen/Qwen-Image-Edit-2509;\n    Qwen/Qwen-Image-->EliGen-Series;\n    EliGen-Series-->DiffSynth-Studio/Qwen-Image-EliGen;\n    DiffSynth-Studio/Qwen-Image-EliGen-->DiffSynth-Studio/Qwen-Image-EliGen-V2;\n    EliGen-Series-->DiffSynth-Studio/Qwen-Image-EliGen-Poster;\n    Qwen/Qwen-Image-->Distill-Series;\n    Distill-Series-->DiffSynth-Studio/Qwen-Image-Distill-Full;\n    Distill-Series-->DiffSynth-Studio/Qwen-Image-Distill-LoRA;\n    Qwen/Qwen-Image-->ControlNet-Series;\n    ControlNet-Series-->Blockwise-ControlNet-Series;\n    Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny;\n    Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth;\n    Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint;\n    ControlNet-Series-->DiffSynth-Studio/Qwen-Image-In-Context-Control-Union;\n    Qwen/Qwen-Image-->DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix;\n```\n\n</details>\n\n<details>\n\n<summary>Examples</summary>\n\nExample code for Qwen-Image is available at: [/examples/qwen_image/](/examples/qwen_image/)\n\n| Model ID | Inference | Low-VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation 
|\n|-|-|-|-|-|-|-|\n|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](/examples/qwen_image/model_inference/Qwen-Image.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image.py)|\n|[Qwen/Qwen-Image-2512](https://www.modelscope.cn/models/Qwen/Qwen-Image-2512)|[code](/examples/qwen_image/model_inference/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-2512.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-2512.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-2512.py)|\n|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)|\n|[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/mode
l_training/validate_lora/Qwen-Image-Edit-2509.py)|\n|[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)|\n|[FireRedTeam/FireRed-Image-Edit-1.0](https://www.modelscope.cn/models/FireRedTeam/FireRed-Image-Edit-1.0)|[code](/examples/qwen_image/model_inference/FireRed-Image-Edit-1.0.py)|[code](/examples/qwen_image/model_inference_low_vram/FireRed-Image-Edit-1.0.py)|[code](/examples/qwen_image/model_training/full/FireRed-Image-Edit-1.0.sh)|[code](/examples/qwen_image/model_training/validate_full/FireRed-Image-Edit-1.0.py)|[code](/examples/qwen_image/model_training/lora/FireRed-Image-Edit-1.0.sh)|[code](/examples/qwen_image/model_training/validate_lora/FireRed-Image-Edit-1.0.py)|\n|[FireRedTeam/FireRed-Image-Edit-1.1](https://www.modelscope.cn/models/FireRedTeam/FireRed-Image-Edit-1.1)|[code](/examples/qwen_image/model_inference/FireRed-Image-Edit-1.1.py)|[code](/examples/qwen_image/model_inference_low_vram/FireRed-Image-Edit-1.1.py)|[code](/examples/qwen_image/model_training/full/FireRed-Image-Edit-1.1.sh)|[code](/examples/qwen_image/model_training/validate_full/FireRed-Image-Edit-1.1.py)|[code](/examples/qwen_image/model_training/lora/FireRed-Image-Edit-1.1.sh)|[code](/examples/qwen_image/model_training/validate_lora/FireRed-Image-Edit-1.1.py)|\n|[lightx2v/Qwen-Image-Edit-2511-Lightning](https://modelscope.cn/models/lightx2v/Qwen-Image-Edit-2511-Lightning)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-Lightning.py)|[code](/examples/qwen_image/model_inference_low_vram
/Qwen-Image-Edit-2511-Lightning.py)|-|-|-|-|\n|[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)|\n|[DiffSynth-Studio/Qwen-Image-Layered-Control](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control.py)|\n|[DiffSynth-Studio/Qwen-Image-Layered-Control-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control-V2)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered-Control-V2.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control-V2.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control-V2.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control-V2.py)|\n|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code
](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|\n|[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|\n|[DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py)|\n|[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](/examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-Full.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py)|\n|[DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA)|[code](/examples/qwen_image/model_inference/Qwen-Image-Distill-LoRA.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-LoRA.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-LoRA.py)|\n|[DiffSynth-Studio/Qwen-Image
-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Canny.py)|\n|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Depth.py)|\n|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[cod
e](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|\n|[DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union)|[code](/examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py)|\n|[DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py)|-|-|-|-|\n|[DiffSynth-Studio/Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L)|[code](/examples/qwen_image/model_inference/Qwen-Image-i2L.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-i2L.py)|-|-|-|-|\n\n</details>\n\n#### FLUX.1: [/docs/en/Model_Details/FLUX.md](/docs/en/Model_Details/FLUX.md)\n\n<details>\n\n<summary>Quick Start</summary>\n\nRunning the following code will quickly load the [black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev) model for inference. VRAM management is enabled, and the framework automatically adjusts model parameter loading based on available GPU memory. 
The model can run with as little as 8 GB of VRAM.\n\n```python\nimport torch\nfrom diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig\n\nvram_config = {\n    \"offload_dtype\": torch.float8_e4m3fn,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.float8_e4m3fn,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e4m3fn,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = FluxImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"flux1-dev.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder/model.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder_2/*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"ae.safetensors\", **vram_config),\n    ],\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 1,\n)\nprompt = \"CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. 
The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her.\"\nimage = pipe(prompt=prompt, seed=0)\nimage.save(\"image.jpg\")\n```\n\n</details>\n\n<details>\n\n<summary>Model Lineage</summary>\n\n```mermaid\ngraph LR;\n    FLUX.1-Series-->black-forest-labs/FLUX.1-dev;\n    FLUX.1-Series-->black-forest-labs/FLUX.1-Krea-dev;\n    FLUX.1-Series-->black-forest-labs/FLUX.1-Kontext-dev;\n    black-forest-labs/FLUX.1-dev-->FLUX.1-dev-ControlNet-Series;\n    FLUX.1-dev-ControlNet-Series-->alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta;\n    FLUX.1-dev-ControlNet-Series-->InstantX/FLUX.1-dev-Controlnet-Union-alpha;\n    FLUX.1-dev-ControlNet-Series-->jasperai/Flux.1-dev-Controlnet-Upscaler;\n    black-forest-labs/FLUX.1-dev-->InstantX/FLUX.1-dev-IP-Adapter;\n    black-forest-labs/FLUX.1-dev-->ByteDance/InfiniteYou;\n    black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/Eligen;\n    black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev;\n    black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev;\n    black-forest-labs/FLUX.1-dev-->ostris/Flex.2-preview;\n    black-forest-labs/FLUX.1-dev-->stepfun-ai/Step1X-Edit;\n    Qwen/Qwen2.5-VL-7B-Instruct-->stepfun-ai/Step1X-Edit;\n    black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/Nexus-GenV2;\n    Qwen/Qwen2.5-VL-7B-Instruct-->DiffSynth-Studio/Nexus-GenV2;\n```\n\n</details>\n\n<details>\n\n<summary>Examples</summary>\n\nExample code for FLUX.1 is available at: [/examples/flux/](/examples/flux/)\n\n| Model ID | Extra Args | Inference | Low-VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation 
|\n|-|-|-|-|-|-|-|-|\n|[black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev)||[code](/examples/flux/model_inference/FLUX.1-dev.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev.py)|[code](/examples/flux/model_training/full/FLUX.1-dev.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev.py)|\n|[black-forest-labs/FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev)||[code](/examples/flux/model_inference/FLUX.1-Krea-dev.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-Krea-dev.py)|[code](/examples/flux/model_training/full/FLUX.1-Krea-dev.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-Krea-dev.py)|[code](/examples/flux/model_training/lora/FLUX.1-Krea-dev.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-Krea-dev.py)|\n|[black-forest-labs/FLUX.1-Kontext-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev)|`kontext_images`|[code](/examples/flux/model_inference/FLUX.1-Kontext-dev.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-Kontext-dev.py)|[code](/examples/flux/model_training/full/FLUX.1-Kontext-dev.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-Kontext-dev.py)|[code](/examples/flux/model_training/lora/FLUX.1-Kontext-dev.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-Kontext-dev.py)|\n|[alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta](https://www.modelscope.cn/models/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta)|`controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev
-Controlnet-Inpainting-Beta.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|\n|[InstantX/FLUX.1-dev-Controlnet-Union-alpha](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-Controlnet-Union-alpha)|`controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Union-alpha.py)|\n|[jasperai/Flux.1-dev-Controlnet-Upscaler](https://www.modelscope.cn/models/jasperai/Flux.1-dev-Controlnet-Upscaler)|`controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Upscaler.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Upscaler.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Upscaler.py)|\n|[InstantX/FLUX.1-dev-IP-Adapter](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-IP-Adapter)|`ipadapter_images`, 
`ipadapter_scale`|[code](/examples/flux/model_inference/FLUX.1-dev-IP-Adapter.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-IP-Adapter.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-IP-Adapter.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-IP-Adapter.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-IP-Adapter.py)|\n|[ByteDance/InfiniteYou](https://www.modelscope.cn/models/ByteDance/InfiniteYou)|`infinityou_id_image`, `infinityou_guidance`, `controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-InfiniteYou.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-InfiniteYou.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-InfiniteYou.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-InfiniteYou.py)|\n|[DiffSynth-Studio/Eligen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)|`eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint`|[code](/examples/flux/model_inference/FLUX.1-dev-EliGen.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py)|-|-|[code](/examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-EliGen.py)|\n|[DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev](https://www.modelscope.cn/models/DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev)|`lora_encoder_inputs`, 
`lora_encoder_scale`|[code](/examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-LoRA-Encoder.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py)|-|-|\n|[DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev)||[code](/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py)|-|-|-|-|-|\n|[stepfun-ai/Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit)|`step1x_reference_image`|[code](/examples/flux/model_inference/Step1X-Edit.py)|[code](/examples/flux/model_inference_low_vram/Step1X-Edit.py)|[code](/examples/flux/model_training/full/Step1X-Edit.sh)|[code](/examples/flux/model_training/validate_full/Step1X-Edit.py)|[code](/examples/flux/model_training/lora/Step1X-Edit.sh)|[code](/examples/flux/model_training/validate_lora/Step1X-Edit.py)|\n|[ostris/Flex.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview)|`flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, 
`flex_control_stop`|[code](/examples/flux/model_inference/FLEX.2-preview.py)|[code](/examples/flux/model_inference_low_vram/FLEX.2-preview.py)|[code](/examples/flux/model_training/full/FLEX.2-preview.sh)|[code](/examples/flux/model_training/validate_full/FLEX.2-preview.py)|[code](/examples/flux/model_training/lora/FLEX.2-preview.sh)|[code](/examples/flux/model_training/validate_lora/FLEX.2-preview.py)|\n|[DiffSynth-Studio/Nexus-GenV2](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2)|`nexus_gen_reference_image`|[code](/examples/flux/model_inference/Nexus-Gen-Editing.py)|[code](/examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py)|[code](/examples/flux/model_training/full/Nexus-Gen.sh)|[code](/examples/flux/model_training/validate_full/Nexus-Gen.py)|[code](/examples/flux/model_training/lora/Nexus-Gen.sh)|[code](/examples/flux/model_training/validate_lora/Nexus-Gen.py)|\n\n</details>\n\n### Video Synthesis\n\nhttps://github.com/user-attachments/assets/1d66ae74-3b02-40a9-acc3-ea95fc039314\n\n#### LTX-2: [/docs/en/Model_Details/LTX-2.md](/docs/en/Model_Details/LTX-2.md)\n\n<details>\n\n<summary>Quick Start</summary>\n\nRunning the following code will quickly load the [Lightricks/LTX-2](https://www.modelscope.cn/models/Lightricks/LTX-2) model for inference. VRAM management is enabled, and the framework automatically adjusts model parameter loading based on available GPU memory. 
The model can run with as little as 8 GB of VRAM.\n\n```python\nimport torch\nfrom diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig\nfrom diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2\n\nvram_config = {\n    \"offload_dtype\": torch.float8_e5m2,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.float8_e5m2,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e5m2,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\n\"\"\"\nOfficial model repo: https://www.modelscope.cn/models/Lightricks/LTX-2\nRepackaged model repo: https://www.modelscope.cn/models/DiffSynth-Studio/LTX-2-Repackage\nFor base models of LTX-2, the official checkpoint (with model config ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-19b-dev.safetensors\"))\nand the repackaged checkpoints (with model config ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"*.safetensors\")) are both supported.\nWe have repackaged the official checkpoints in the DiffSynth-Studio/LTX-2-Repackage repo to support separate loading of different submodules\nand to avoid redundant memory usage when users only want to use part of the model.\n\"\"\"\n# use the repackaged modelconfig from \"DiffSynth-Studio/LTX-2-Repackage\" to avoid redundant model loading\npipe = LTX2AudioVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\", origin_file_pattern=\"model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"transformer.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"text_encoder_post_modules.safetensors\", **vram_config),\n        
ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"video_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"audio_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"audio_vocoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"video_vae_encoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-spatial-upscaler-x2-1.0.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\"),\n    stage2_lora_config=ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-19b-distilled-lora-384.safetensors\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\n\n# use the following modelconfig if you want to initialize the model from the official checkpoints in \"Lightricks/LTX-2\"\n# pipe = LTX2AudioVideoPipeline.from_pretrained(\n#     torch_dtype=torch.bfloat16,\n#     device=\"cuda\",\n#     model_configs=[\n#         ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\", origin_file_pattern=\"model-*.safetensors\", **vram_config),\n#         ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-19b-dev.safetensors\", **vram_config),\n#         ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-spatial-upscaler-x2-1.0.safetensors\", **vram_config),\n#     ],\n#     tokenizer_config=ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\"),\n#     stage2_lora_config=ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-19b-distilled-lora-384.safetensors\"),\n#     vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n# )\n\nprompt = \"A girl is very happy, she 
is speaking: \\\"I enjoy working with Diffsynth-Studio, it's a perfect framework.\\\"\"\nnegative_prompt = (\n    \"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, \"\n    \"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, \"\n    \"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, \"\n    \"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of \"\n    \"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent \"\n    \"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny \"\n    \"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, \"\n    \"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, \"\n    \"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward \"\n    \"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, \"\n    \"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts.\"\n)\nheight, width, num_frames = 512 * 2, 768 * 2, 121\nvideo, audio = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    seed=43,\n    height=height,\n    width=width,\n    num_frames=num_frames,\n    tiled=True,\n    use_two_stage_pipeline=True,\n)\nwrite_video_audio_ltx2(\n    video=video,\n    audio=audio,\n    output_path='ltx2_twostage.mp4',\n    fps=24,\n    audio_sample_rate=24000,\n)\n```\n\n</details>\n\n<details>\n\n<summary>Examples</summary>\n\nExample code for LTX-2 is available at: [/examples/ltx2/](/examples/ltx2/)\n\n| Model ID | Extra Args | Inference | Low-VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation 
|\n|-|-|-|-|-|-|-|-|\n|[Lightricks/LTX-2.3: OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2.3-I2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-I2AV-OneStage.py)|[code](/examples/ltx2/model_training/full/LTX-2.3-I2AV-splited.sh)|[code](/examples/ltx2/model_training/validate_full/LTX-2.3-I2AV.py)|[code](/examples/ltx2/model_training/lora/LTX-2.3-I2AV-splited.sh)|[code](/examples/ltx2/model_training/validate_lora/LTX-2.3-I2AV.py)|\n|[Lightricks/LTX-2.3: TwoStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2.3-I2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-I2AV-TwoStage.py)|-|-|-|-|\n|[Lightricks/LTX-2.3: DistilledPipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2.3-I2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-I2AV-DistilledPipeline.py)|-|-|-|-|\n|[Lightricks/LTX-2.3: OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)||[code](/examples/ltx2/model_inference/LTX-2.3-T2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-OneStage.py)|[code](/examples/ltx2/model_training/full/LTX-2.3-T2AV-splited.sh)|[code](/examples/ltx2/model_training/validate_full/LTX-2.3-T2AV.py)|[code](/examples/ltx2/model_training/lora/LTX-2.3-T2AV-splited.sh)|[code](/examples/ltx2/model_training/validate_lora/LTX-2.3-T2AV.py)|\n|[Lightricks/LTX-2.3: TwoStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)||[code](/examples/ltx2/model_inference/LTX-2.3-T2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-TwoStage.py)|-|-|-|-|\n|[Lightricks/LTX-2.3: 
DistilledPipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)||[code](/examples/ltx2/model_inference/LTX-2.3-T2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-DistilledPipeline.py)|-|-|-|-|\n|[Lightricks/LTX-2.3: A2V](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`retake_audio`,`audio_sample_rate`,`retake_audio_regions`|[code](/examples/ltx2/model_inference/LTX-2.3-A2V-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-A2V-TwoStage.py)|-|-|-|-|\n|[Lightricks/LTX-2.3: Retake](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`retake_video`,`retake_video_regions`,`retake_audio`,`audio_sample_rate`,`retake_audio_regions`|[code](/examples/ltx2/model_inference/LTX-2.3-T2AV-TwoStage-Retake.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-TwoStage-Retake.py)|-|-|-|-|\n|[Lightricks/LTX-2.3-22b-IC-LoRA-Union-Control](https://www.modelscope.cn/models/Lightricks/LTX-2.3-22b-IC-LoRA-Union-Control)|`in_context_videos`,`in_context_downsample_factor`|[code](/examples/ltx2/model_inference/LTX-2.3-T2AV-IC-LoRA-Union-Control.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-IC-LoRA-Union-Control.py)|-|-|[code](/examples/ltx2/model_training/lora/LTX-2.3-T2AV-IC-LoRA-splited.sh)|[code](/examples/ltx2/model_training/validate_lora/LTX-2.3-T2AV-IC-LoRA.py)|\n|[Lightricks/LTX-2.3-22b-IC-LoRA-Motion-Track-Control](https://www.modelscope.cn/models/Lightricks/LTX-2.3-22b-IC-LoRA-Motion-Track-Control)|`in_context_videos`,`in_context_downsample_factor`|[code](/examples/ltx2/model_inference/LTX-2.3-T2AV-IC-LoRA-Motion-Track-Control.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-IC-LoRA-Motion-Track-Control.py)|-|-|[code](/examples/ltx2/model_training/lora/LTX-2.3-T2AV-IC-LoRA-splited.sh)|[code](/examples/ltx2/model_training/validate_lora/LTX-2.3-T2AV-IC-LoRA.py)|\n|[Lightricks/LTX-2: 
OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-OneStage.py)|[code](/examples/ltx2/model_training/full/LTX-2-T2AV-splited.sh)|[code](/examples/ltx2/model_training/validate_full/LTX-2-T2AV.py)|[code](/examples/ltx2/model_training/lora/LTX-2-T2AV-splited.sh)|[code](/examples/ltx2/model_training/validate_lora/LTX-2-T2AV.py)|\n|[Lightricks/LTX-2-19b-IC-LoRA-Union-Control](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-IC-LoRA-Union-Control)|`in_context_videos`,`in_context_downsample_factor`|[code](/examples/ltx2/model_inference/LTX-2-T2AV-IC-LoRA-Union-Control.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-IC-LoRA-Union-Control.py)|-|-|[code](/examples/ltx2/model_training/lora/LTX-2-T2AV-IC-LoRA-splited.sh)|[code](/examples/ltx2/model_training/validate_lora/LTX-2-T2AV-IC-LoRA.py)|\n|[Lightricks/LTX-2-19b-IC-LoRA-Detailer](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-IC-LoRA-Detailer)|`in_context_videos`,`in_context_downsample_factor`|[code](/examples/ltx2/model_inference/LTX-2-T2AV-IC-LoRA-Detailer.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-IC-LoRA-Detailer.py)|-|-|[code](/examples/ltx2/model_training/lora/LTX-2-T2AV-IC-LoRA-splited.sh)|[code](/examples/ltx2/model_training/validate_lora/LTX-2-T2AV-IC-LoRA.py)|\n|[Lightricks/LTX-2: TwoStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-TwoStage.py)|-|-|-|-|\n|[Lightricks/LTX-2: DistilledPipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-DistilledPipeline.py)|-|-|-|-|\n|[Lightricks/LTX-2: 
OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-OneStage.py)|-|-|-|-|\n|[Lightricks/LTX-2: TwoStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-TwoStage.py)|-|-|-|-|\n|[Lightricks/LTX-2: DistilledPipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-DistilledPipeline.py)|-|-|-|-|\n|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-In](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-In)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-In.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-In.py)|-|-|-|-|\n|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Out](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Out)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Out.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Out.py)|-|-|-|-|\n|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Left](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Left)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Left.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Left.py)|-|-|-|-|\n|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Right](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Right)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Right.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dol
ly-Right.py)|-|-|-|-|\n|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Up](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Up)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Jib-Up.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Jib-Up.py)|-|-|-|-|\n|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Down](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Down)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Jib-Down.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Jib-Down.py)|-|-|-|-|\n|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Static](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Static)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Static.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Static.py)|-|-|-|-|\n\n</details>\n\n#### Wan: [/docs/en/Model_Details/Wan.md](/docs/en/Model_Details/Wan.md)\n\n<details>\n\n<summary>Quick Start</summary>\n\nRunning the following code will quickly load the [Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B) model for inference. VRAM management is enabled, and the framework automatically adjusts model parameter loading based on available GPU memory. 
The model can run with as little as 8 GB of VRAM.\n\n```python\nimport torch\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\", **vram_config),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"Wan2.1_VAE.pth\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 2,\n)\n\nvideo = pipe(\n    prompt=\"纪实摄影风格画面，一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄，两只耳朵立起，神情专注而欢快。阳光洒在它身上，使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地，偶尔点缀着几朵野花，远处隐约可见蓝天和几片白云。透视感鲜明，捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    seed=0, tiled=True,\n)\nsave_video(video, \"video.mp4\", fps=15, quality=5)\n```\n\n</details>\n\n<details>\n\n<summary>Model Lineage</summary>\n\n```mermaid\ngraph LR;\n    Wan-Series-->Wan2.1-Series;\n    Wan-Series-->Wan2.2-Series;\n    Wan2.1-Series-->Wan-AI/Wan2.1-T2V-1.3B;\n    Wan2.1-Series-->Wan-AI/Wan2.1-T2V-14B;\n    Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-I2V-14B-480P;\n    
Wan-AI/Wan2.1-I2V-14B-480P-->Wan-AI/Wan2.1-I2V-14B-720P;\n    Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-FLF2V-14B-720P;\n    Wan-AI/Wan2.1-T2V-1.3B-->iic/VACE-Wan2.1-1.3B-Preview;\n    iic/VACE-Wan2.1-1.3B-Preview-->Wan-AI/Wan2.1-VACE-1.3B;\n    Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-VACE-14B;\n    Wan-AI/Wan2.1-T2V-1.3B-->Wan2.1-Fun-1.3B-Series;\n    Wan2.1-Fun-1.3B-Series-->PAI/Wan2.1-Fun-1.3B-InP;\n    Wan2.1-Fun-1.3B-Series-->PAI/Wan2.1-Fun-1.3B-Control;\n    Wan-AI/Wan2.1-T2V-14B-->Wan2.1-Fun-14B-Series;\n    Wan2.1-Fun-14B-Series-->PAI/Wan2.1-Fun-14B-InP;\n    Wan2.1-Fun-14B-Series-->PAI/Wan2.1-Fun-14B-Control;\n    Wan-AI/Wan2.1-T2V-1.3B-->Wan2.1-Fun-V1.1-1.3B-Series;\n    Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-Control;\n    Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-InP;\n    Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera;\n    Wan-AI/Wan2.1-T2V-14B-->Wan2.1-Fun-V1.1-14B-Series;\n    Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-Control;\n    Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-InP;\n    Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-Control-Camera;\n    Wan-AI/Wan2.1-T2V-1.3B-->DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1;\n    Wan-AI/Wan2.1-T2V-14B-->krea/krea-realtime-video;\n    Wan-AI/Wan2.1-T2V-14B-->meituan-longcat/LongCat-Video;\n    Wan-AI/Wan2.1-I2V-14B-720P-->ByteDance/Video-As-Prompt-Wan2.1-14B;\n    Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.2-Animate-14B;\n    Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.2-S2V-14B;\n    Wan2.2-Series-->Wan-AI/Wan2.2-T2V-A14B;\n    Wan2.2-Series-->Wan-AI/Wan2.2-I2V-A14B;\n    Wan2.2-Series-->Wan-AI/Wan2.2-TI2V-5B;\n    Wan-AI/Wan2.2-T2V-A14B-->Wan2.2-Fun-Series;\n    Wan2.2-Fun-Series-->PAI/Wan2.2-VACE-Fun-A14B;\n    Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-InP;\n    Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-Control;\n    Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-Control-Camera;\n```\n\n</details>\n\n<details>\n\n<summary>Examples</summary>\n\nExample code 
for Wan is available at: [/examples/wanvideo/](/examples/wanvideo/)\n\n| Model ID | Extra Inputs | Inference | Low VRAM Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training |\n|-|-|-|-|-|-|-|-|\n|[Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-T2V-1.3B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-1.3B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-1.3B.py)|\n|[Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-T2V-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-T2V-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py)|\n|[Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|`inpu
t_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-I2V-14B-480P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py)|\n|[Wan-AI/Wan2.1-I2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P)|`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-I2V-14B-720P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-I2V-14B-720P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-720P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-720P.py)|\n|[Wan-AI/Wan2.1-FLF2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P)|`input_image`, 
`end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-FLF2V-14B-720P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-FLF2V-14B-720P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-FLF2V-14B-720P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-FLF2V-14B-720P.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py)|\n|[iic/VACE-Wan2.1-1.3B-Preview](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview)|`vace_control_video`, `vace_reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B-Preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-VACE-1.3B-Preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B-Preview.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B-Preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B-Preview.py)|\n|[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|`vace_control_video`, 
`vace_reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-VACE-1.3B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B.py)|\n|[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, `vace_reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-VACE-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-VACE-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py)|\n|[PAI/Wan2.1-Fun-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP)|`input_image`, 
`end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-1.3B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py)|\n|[PAI/Wan2.1-Fun-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control)|`control_video`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-1.3B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py)|\n|[PAI/Wan2.1-Fun-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP)|`input_image`, 
`end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-InP.py)|\n|[PAI/Wan2.1-Fun-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control)|`control_video`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-Control.py)|\n|[PAI/Wan2.1-Fun-V1.1-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control)|`control_video`, 
`reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py)|\n|[PAI/Wan2.1-Fun-V1.1-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control)|`control_video`, `reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py)|\n|[PAI/Wan2.1-Fun-V1.1-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-InP)|`input_image`, 
`end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-InP.py)|\n|[PAI/Wan2.1-Fun-V1.1-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-InP)|`input_image`, `end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-InP.py)|\n|[PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera)|`control_camera_video`, 
`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|\n|[PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera)|`control_camera_video`, 
`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|\n|[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-1.3b-speedcontrol-v1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)|\n|[krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/
examples/wanvideo/model_inference/krea-realtime-video.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/krea-realtime-video.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/krea-realtime-video.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/krea-realtime-video.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/krea-realtime-video.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/krea-realtime-video.py)|\n|[meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video)|`longcat_video`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/LongCat-Video.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/LongCat-Video.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/LongCat-Video.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/LongCat-Video.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/LongCat-Video.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/LongCat-Video.py)|\n|[ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B)|`vap_video`, 
`vap_prompt`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Video-As-Prompt-Wan2.1-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Video-As-Prompt-Wan2.1-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Video-As-Prompt-Wan2.1-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Video-As-Prompt-Wan2.1-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Video-As-Prompt-Wan2.1-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Video-As-Prompt-Wan2.1-14B.py)|\n|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-T2V-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py)|\n|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-I2V-A14B.p
y)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py)|\n|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-TI2V-5B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py)|\n|[Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B)|`input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, 
`animate_mask_video`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Animate-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-Animate-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py)|\n|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-S2V-14B_multi_clips.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-S2V-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-S2V-14B.py)|\n|[PAI/Wan2.2-VACE-Fun-A14B](https://www.modelscope.cn/models/PAI/Wan2.2-VACE-Fun-A14B)|`vace_control_video`, 
`vace_reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-VACE-Fun-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-VACE-Fun-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-VACE-Fun-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-VACE-Fun-A14B.py)|\n|[PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP)|`input_image`, `end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-Fun-A14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py)|\n|[PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control)|`control_video`, 
`reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-Fun-A14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py)|\n|[PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera)|`control_camera_video`, `input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-Fun-A14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py)|\n|[openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p)|`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_inference/MOVA-360p-I2AV.py)
|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_inference_low_vram/MOVA-360p-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/full/MOVA-360P-I2AV.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_full/MOVA-360p-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/lora/MOVA-360P-I2AV.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_lora/MOVA-360p-I2AV.py)|\n|[openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p)|`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_inference/MOVA-720p-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_inference_low_vram/MOVA-720p-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/full/MOVA-720P-I2AV.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_full/MOVA-720p-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/lora/MOVA-720P-I2AV.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_lora/MOVA-720p-I2AV.py)|\n|[Wan-AI/WanToDance-14B (global model)](https://modelscope.cn/models/Wan-AI/WanToDance-14B)|`wantodance_music_path`, `wantodance_reference_image`, `wantodance_fps`, `wantodance_keyframes`, 
`wantodance_keyframes_mask`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/WanToDance-14B-global.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/WanToDance-14B-global.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/WanToDance-14B-global.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/WanToDance-14B-global.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/WanToDance-14B-global.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/WanToDance-14B-global.py)|\n|[Wan-AI/WanToDance-14B (local model)](https://modelscope.cn/models/Wan-AI/WanToDance-14B)|`wantodance_music_path`, `wantodance_reference_image`, `wantodance_fps`, `wantodance_keyframes`, `wantodance_keyframes_mask`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/WanToDance-14B-local.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/WanToDance-14B-local.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/WanToDance-14B-local.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/WanToDance-14B-local.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/WanToDance-14B-local.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/WanToDance-14B-local.py)|\n\n</details>\n\n## Innovative Achievements\n\nDiffSynth-Studio is not just an engineered model framework, but also an incubator for innovative 
achievements.\n\n<details>\n\n<summary>Spectral Evolution Search: Efficient Inference-Time Scaling for Reward-Aligned Image Generation</summary>\n\n- Paper: [Spectral Evolution Search: Efficient Inference-Time Scaling for Reward-Aligned Image Generation](https://arxiv.org/abs/2602.03208)\n- Sample Code: [/docs/en/Research_Tutorial/inference_time_scaling.md](/docs/en/Research_Tutorial/inference_time_scaling.md)\n\n|FLUX.1-dev|FLUX.1-dev + SES|Qwen-Image|Qwen-Image + SES|\n|-|-|-|-|\n|![Image](https://github.com/user-attachments/assets/5be15dc6-2805-4822-b04c-2573fc0f45f0)|![Image](https://github.com/user-attachments/assets/e71b8c20-1629-41d9-b0ff-185805c1da4e)|![Image](https://github.com/user-attachments/assets/7a73c968-133a-4545-9aa2-205533861cd4)|![Image](https://github.com/user-attachments/assets/c8390b22-14fe-48a0-a6e6-d6556d31235e)|\n\n</details>\n\n\n<details>\n\n<summary>VIRAL: Visual In-Context Reasoning via Analogy in Diffusion Transformers</summary>\n\n- Paper: [VIRAL: Visual In-Context Reasoning via Analogy in Diffusion Transformers](https://arxiv.org/abs/2602.03210)\n- Sample Code: [/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-ICEdit.py](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-ICEdit.py)\n- Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-2511-ICEdit-LoRA)\n\n|Example 1|Example 2|Query|Output|\n|-|-|-|-|\n|![Image](https://github.com/user-attachments/assets/380d2670-47bf-41cd-b5c9-37110cc4a943)|![Image](https://github.com/user-attachments/assets/7ceaf345-0992-46e6-b38f-394c2065b165)|![Image](https://github.com/user-attachments/assets/f7c26c21-6894-4d9e-b570-f1d44ca7c1de)|![Image](https://github.com/user-attachments/assets/c2bebe3b-5984-41ba-94bf-9509f6a8a990)|\n\n</details>\n\n\n<details>\n\n<summary>AttriCtrl: Attribute Intensity Control for Image Generation Models</summary>\n\n- Paper: [AttriCtrl: Fine-Grained Control of Aesthetic Attribute Intensity in Diffusion 
Models](https://arxiv.org/abs/2508.02151)\n- Sample Code: [/examples/flux/model_inference/FLUX.1-dev-AttriCtrl.py](/examples/flux/model_inference/FLUX.1-dev-AttriCtrl.py)\n- Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev)\n\n|brightness scale = 0.1|brightness scale = 0.3|brightness scale = 0.5|brightness scale = 0.7|brightness scale = 0.9|\n|-|-|-|-|-|\n|![Image](https://github.com/user-attachments/assets/e74b32a5-5b2e-4c87-9df8-487c0f8366b7)|![Image](https://github.com/user-attachments/assets/bfe8bec2-9e55-493d-9a26-7e9cce28e03d)|![Image](https://github.com/user-attachments/assets/b099dfe3-ff1f-4b96-894c-d48bbe92db7a)|![Image](https://github.com/user-attachments/assets/0a6b2982-deab-4b0d-91ad-888782de01c9)|![Image](https://github.com/user-attachments/assets/fcecb755-7d03-4020-b83a-13ad2b38705c)|\n\n</details>\n\n\n<details>\n\n<summary>AutoLoRA: Automated LoRA Retrieval and Fusion</summary>\n\n- Paper: [AutoLoRA: Automatic LoRA Retrieval and Fine-Grained Gated Fusion for Text-to-Image Generation](https://arxiv.org/abs/2508.02107)\n- Sample Code: [/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py](/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py)\n- Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev)\n\n||[LoRA 1](https://modelscope.cn/models/cancel13/cxsk)|[LoRA 2](https://modelscope.cn/models/wy413928499/xuancai2)|[LoRA 3](https://modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1)|[LoRA 4](https://modelscope.cn/models/hongyanbujian/JPL)|\n|-|-|-|-|-|\n|[LoRA 1](https://modelscope.cn/models/cancel13/cxsk)                              
|![Image](https://github.com/user-attachments/assets/01c54d5a-4f00-4c2e-982a-4ec0a4c6a6e3)|![Image](https://github.com/user-attachments/assets/e6621457-b9f1-437c-bcc8-3e12e41646de)|![Image](https://github.com/user-attachments/assets/4b7f721f-a2e5-416c-af2c-b53ef236c321)|![Image](https://github.com/user-attachments/assets/802d554e-0402-482c-9f28-87605f8fe318)|\n|[LoRA 2](https://modelscope.cn/models/wy413928499/xuancai2)                       |![Image](https://github.com/user-attachments/assets/e6621457-b9f1-437c-bcc8-3e12e41646de)|![Image](https://github.com/user-attachments/assets/43720a9f-aa27-4918-947d-545389375d46)|![Image](https://github.com/user-attachments/assets/418c725b-6d35-41f4-b18f-c7e3867cc142)|![Image](https://github.com/user-attachments/assets/8c8f22fa-9643-4019-b6d7-396d8b7fed9a)|\n|[LoRA 3](https://modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1)  |![Image](https://github.com/user-attachments/assets/4b7f721f-a2e5-416c-af2c-b53ef236c321)|![Image](https://github.com/user-attachments/assets/418c725b-6d35-41f4-b18f-c7e3867cc142)|![Image](https://github.com/user-attachments/assets/041a3f9a-c7b4-4311-8582-cb71a7226d80)|![Image](https://github.com/user-attachments/assets/b54ebaa4-31a7-4536-a2c1-496adba0c013)|\n|[LoRA 4](https://modelscope.cn/models/hongyanbujian/JPL)                          |![Image](https://github.com/user-attachments/assets/802d554e-0402-482c-9f28-87605f8fe318)|![Image](https://github.com/user-attachments/assets/8c8f22fa-9643-4019-b6d7-396d8b7fed9a)|![Image](https://github.com/user-attachments/assets/b54ebaa4-31a7-4536-a2c1-496adba0c013)|![Image](https://github.com/user-attachments/assets/a640fd54-3192-49a0-9281-b43d9ba64f09)|\n\n</details>\n\n\n<details>\n\n<summary>Nexus-Gen: Unified Architecture for Image Understanding, Generation, and Editing</summary>\n\n- Detailed Page: https://github.com/modelscope/Nexus-Gen\n- Paper: [Nexus-Gen: Unified Image Understanding, Generation, and Editing via Prefilled Autoregression in 
Shared Embedding Space](https://arxiv.org/pdf/2504.21356)\n- Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2), [HuggingFace](https://huggingface.co/modelscope/Nexus-GenV2)\n- Dataset: [ModelScope Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Nexus-Gen-Training-Dataset)\n- Online Experience: [ModelScope Nexus-Gen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/Nexus-Gen)\n\n![](https://github.com/modelscope/Nexus-Gen/raw/main/assets/illustrations/gen_edit.jpg)\n\n</details>\n\n\n<details>\n\n<summary>ArtAug: Aesthetic Enhancement for Image Generation Models</summary>\n\n- Detailed Page: [./examples/ArtAug/](./examples/ArtAug/)\n- Paper: [ArtAug: Enhancing Text-to-Image Generation through Synthesis-Understanding Interaction](https://arxiv.org/abs/2412.12888)\n- Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ArtAug-lora-FLUX.1dev-v1)\n- Online Experience: [ModelScope AIGC Tab](https://www.modelscope.cn/aigc/imageGeneration?tab=advanced&versionId=7228&modelType=LoRA&sdVersion=FLUX_1&modelUrl=modelscope%3A%2F%2FDiffSynth-Studio%2FArtAug-lora-FLUX.1dev-v1%3Frevision%3Dv1.0)\n\n|FLUX.1-dev|FLUX.1-dev + ArtAug LoRA|\n|-|-|\n|![image_1_base](https://github.com/user-attachments/assets/e1d5c505-b423-45fe-be01-25c2758f5417)|![image_1_enhance](https://github.com/user-attachments/assets/335908e3-d0bd-41c2-9d99-d10528a2d719)|\n\n</details>\n\n\n<details>\n\n<summary>EliGen: Precise Image Partition Control</summary>\n\n- Paper: [EliGen: Entity-Level Controlled Image Generation with Regional Attention](https://arxiv.org/abs/2501.01097)\n- Sample Code: [/examples/flux/model_inference/FLUX.1-dev-EliGen.py](/examples/flux/model_inference/FLUX.1-dev-EliGen.py)\n- Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen), [HuggingFace](https://huggingface.co/modelscope/EliGen)\n- Online Experience: [ModelScope 
EliGen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/EliGen)\n- Dataset: [EliGen Train Set](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet)\n\n|Entity Control Region|Generated Image|\n|-|-|\n|![eligen_example_2_mask_0](https://github.com/user-attachments/assets/1c6d9445-5022-4d91-ad2e-dc05321883d1)|![eligen_example_2_0](https://github.com/user-attachments/assets/86739945-cb07-4a49-b3b3-3bb65c90d14f)|\n\n</details>\n\n\n<details>\n\n<summary>ExVideo: Extended Training for Video Generation Models</summary>\n\n- Project Page: [Project Page](https://ecnu-cilab.github.io/ExVideoProjectPage/)\n- Paper: [ExVideo: Extending Video Diffusion Models via Parameter-Efficient Post-Tuning](https://arxiv.org/abs/2406.14130)\n- Sample Code: Please refer to the [older version](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3/examples/ExVideo)\n- Model: [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1)\n\nhttps://github.com/modelscope/DiffSynth-Studio/assets/35051019/d97f6aa9-8064-4b5b-9d49-ed6001bb9acc\n\n</details>\n\n\n<details>\n\n<summary>Diffutoon: High-Resolution Anime-Style Video Rendering</summary>\n\n- Project Page: [Project Page](https://ecnu-cilab.github.io/DiffutoonProjectPage/)\n- Paper: [Diffutoon: High-Resolution Editable Toon Shading via Diffusion Models](https://arxiv.org/abs/2401.16224)\n- Sample Code: Please refer to the [older version](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3/examples/Diffutoon)\n\nhttps://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/b54c05c5-d747-4709-be5e-b39af82404dd\n\n</details>\n\n\n<details>\n\n<summary>DiffSynth: The Original Version of This Project</summary>\n\n- Project Page: [Project Page](https://ecnu-cilab.github.io/DiffSynth.github.io/)\n- Paper: [DiffSynth: Latent In-Iteration Deflickering for 
Realistic Video Synthesis](https://arxiv.org/abs/2308.03463)\n- Sample Code: Please refer to the [older version](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3/examples/diffsynth)\n\nhttps://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-4481-b79f-0c3a7361a1ea\n\n</details>\n"
  },
  {
    "path": "README_zh.md",
    "content": "# DiffSynth-Studio\n\n<a href=\"https://github.com/modelscope/DiffSynth-Studio\"><img src=\".github/workflows/logo.gif\" title=\"Logo\" style=\"max-width:100%;\" width=\"55\" /></a> <a href=\"https://trendshift.io/repositories/10946\" target=\"_blank\"><img src=\"https://trendshift.io/api/badge/repositories/10946\" alt=\"modelscope%2FDiffSynth-Studio | Trendshift\" style=\"width: 250px; height: 55px;\" width=\"250\" height=\"55\"/></a></p>\n\n[![PyPI](https://img.shields.io/pypi/v/DiffSynth)](https://pypi.org/project/DiffSynth/)\n[![license](https://img.shields.io/github/license/modelscope/DiffSynth-Studio.svg)](https://github.com/modelscope/DiffSynth-Studio/blob/master/LICENSE)\n[![open issues](https://isitmaintained.com/badge/open/modelscope/DiffSynth-Studio.svg)](https://github.com/modelscope/DiffSynth-Studio/issues)\n[![GitHub pull-requests](https://img.shields.io/github/issues-pr/modelscope/DiffSynth-Studio.svg)](https://GitHub.com/modelscope/DiffSynth-Studio/pull/)\n[![GitHub latest commit](https://badgen.net/github/last-commit/modelscope/DiffSynth-Studio)](https://GitHub.com/modelscope/DiffSynth-Studio/commit/)\n\n[Switch to English](./README.md)\n\n## 简介\n\n> DiffSynth-Studio 文档：[中文版](https://diffsynth-studio-doc.readthedocs.io/zh-cn/latest/)、[English version](https://diffsynth-studio-doc.readthedocs.io/en/latest/)\n\n欢迎来到 Diffusion 模型的魔法世界！DiffSynth-Studio 是由[魔搭社区](https://www.modelscope.cn/)团队开发和维护的开源 Diffusion 模型引擎。我们期望以框架建设孵化技术创新，凝聚开源社区的力量，探索生成式模型技术的边界！\n\nDiffSynth 目前包括两个开源项目：\n* [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio): 聚焦于激进的技术探索，面向学术界，提供更前沿的模型能力支持。\n* [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine): 聚焦于稳定的模型部署，面向工业界，提供更高的计算性能与更稳定的功能。\n\n[DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio) 与 [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine) 是魔搭社区 AIGC 专区的核心引擎，欢迎体验我们精心打造的产品化功能：\n\n* 魔搭社区 AIGC 专区 (面向中国用户): https://modelscope.cn/aigc/home\n* ModelScope Civision 
(for global users): https://modelscope.ai/civision/home\n\n我们相信，一个完善的开源代码框架能够降低技术探索的门槛，我们基于这个代码库搞出了不少[有意思的技术](#创新成果)。或许你也有许多天马行空的构想，借助 DiffSynth-Studio，你可以快速实现这些想法。为此，我们为开发者准备了详细的文档，我们希望通过这些文档，帮助开发者理解 Diffusion 模型的原理，更期待与你一同拓展技术的边界。\n\n## 更新历史\n\n> DiffSynth-Studio 经历了大版本更新，部分旧功能已停止维护，如需使用旧版功能，请切换到大版本更新前的[最后一个历史版本](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3)。\n\n> 目前本项目的开发人员有限，大部分工作由 [Artiprocher](https://github.com/Artiprocher) 和 [mi804](https://github.com/mi804) 负责，因此新功能的开发进展会比较缓慢，issue 的回复和解决速度有限，我们对此感到非常抱歉，请各位开发者理解。\n\n- **2026年1月19日** 新增对 [openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p) 和 [openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p) 模型的支持，包括完整的训练和推理功能。[文档](/docs/zh/Model_Details/Wan.md)和[示例代码](/examples/mova/)现已可用。\n\n- **2026年3月12日** 我们新增了 [LTX-2.3](https://modelscope.cn/models/Lightricks/LTX-2.3) 音视频生成模型的支持，模型支持的功能包括文生音视频、图生音视频、IC-LoRA控制、音频生视频、音视频局部Inpainting，框架支持完整的推理和训练功能。详细信息请参考 [文档](/docs/zh/Model_Details/LTX-2.md) 和 [示例代码](/examples/ltx2/)。\n\n- **2026年3月3日** 我们发布了 [DiffSynth-Studio/Qwen-Image-Layered-Control-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control-V2) 模型，这是 Qwen-Image-Layered-Control 的更新版本。除了原本就支持的文本引导功能，新增了画笔控制的图层拆分能力。\n\n- **2026年3月2日** 新增对[Anima](https://modelscope.cn/models/circlestone-labs/Anima)的支持，详见[文档](docs/zh/Model_Details/Anima.md)。这是一个有趣的动漫风格图像生成模型，我们期待其后续的模型更新。\n\n<details>\n<summary>更多</summary>\n\n- **2026年2月26日** 新增对[LTX-2](https://www.modelscope.cn/models/Lightricks/LTX-2)音视频生成模型全量微调与LoRA训练支持，详见[文档](docs/zh/Model_Details/LTX-2.md)。\n\n- **2026年2月10日** 新增对[LTX-2](https://www.modelscope.cn/models/Lightricks/LTX-2)音视频生成模型的推理支持，详见[文档](docs/zh/Model_Details/LTX-2.md)，后续将推进模型训练的支持。\n\n- **2026年2月2日** Research Tutorial 的第一篇文档上线，带你从零开始训练一个 0.1B 的小型文生图模型，详见[文档](/docs/zh/Research_Tutorial/train_from_scratch.md)、[模型](https://modelscope.cn/models/DiffSynth-Studio/AAAMyModel)，我们希望 DiffSynth-Studio 能够成为一个更强大的 
Diffusion 模型训练框架。\n\n- **2026年1月27日** [Z-Image](https://modelscope.cn/models/Tongyi-MAI/Z-Image) 发布，我们的 [Z-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-i2L) 模型同步发布，在[魔搭创空间](https://modelscope.cn/studios/DiffSynth-Studio/Z-Image-i2L)可直接体验，详见[文档](/docs/zh/Model_Details/Z-Image.md)。\n\n- **2026年1月19日** 新增对 [FLUX.2-klein-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B) 和 [FLUX.2-klein-9B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B) 模型的支持，包括完整的训练和推理功能。[文档](/docs/zh/Model_Details/FLUX2.md)和[示例代码](/examples/flux2/)现已可用。\n\n- **2026年1月12日** 我们训练并开源了一个文本引导的图层拆分模型（[模型链接](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)），这一模型输入一张图与一段文本描述，模型会将图像中与文本描述相关的图层拆分出来。更多细节请阅读我们的 blog（[中文版](https://modelscope.cn/learn/4938)、[英文版](https://huggingface.co/blog/kelseye/qwen-image-layered-control)）。\n\n- **2025年12月24日** 我们基于 Qwen-Image-Edit-2511 训练了一个 In-Context Editing LoRA 模型（[模型链接](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-2511-ICEdit-LoRA)），这个模型可以输入三张图：图A、图B、图C，模型会自行分析图A到图B的变化，并将这样的变化应用到图C，生成图D。更多细节请阅读我们的 blog（[中文版](https://mp.weixin.qq.com/s/41aEiN3lXKGCJs1-we4Q2g)、[英文版](https://huggingface.co/blog/kelseye/qwen-image-edit-2511-icedit-lora)）。\n\n- **2025年12月9日** 我们基于 DiffSynth-Studio 2.0 训练了一个疯狂的模型：[Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L)（Image to LoRA）。这一模型以图像为输入，以 LoRA 为输出。尽管这个版本的模型在泛化能力、细节保持能力等方面还有很大改进空间，我们仍将这些模型开源，以启发更多创新性的研究工作。更多细节，请参考我们的 [blog](https://huggingface.co/blog/kelseye/qwen-image-i2l)。\n\n- **2025年12月4日** DiffSynth-Studio 2.0 发布！众多新功能上线\n  - [文档](/docs/zh/README.md)上线：我们的文档还在持续优化更新中\n  - [显存管理](/docs/zh/Pipeline_Usage/VRAM_management.md)模块升级，支持 Layer 级别的 Disk Offload，同时释放内存与显存\n  - 新模型支持\n    - Z-Image Turbo: [模型](https://www.modelscope.ai/models/Tongyi-MAI/Z-Image-Turbo)、[文档](/docs/zh/Model_Details/Z-Image.md)、[代码](/examples/z_image/)\n    - FLUX.2-dev: 
[模型](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)、[文档](/docs/zh/Model_Details/FLUX2.md)、[代码](/examples/flux2/)\n  - 训练框架升级\n    - [拆分训练](/docs/zh/Training/Split_Training.md)：支持自动化地将训练过程拆分为数据处理和训练两阶段（即使训练的是 ControlNet 或其他任意模型），在数据处理阶段进行文本编码、VAE 编码等不需要梯度回传的计算，在训练阶段处理其他计算。速度更快，显存需求更少。\n    - [差分 LoRA 训练](/docs/zh/Training/Differential_LoRA.md)：这是我们曾在 [ArtAug](https://www.modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1) 中使用的训练技术，目前已可用于任意模型的 LoRA 训练。\n    - [FP8 训练](/docs/zh/Training/FP8_Precision.md)：FP8 在训练中支持应用到任意非训练模型，即梯度关闭或者梯度仅影响 LoRA 权重的模型。\n\n- **2025年11月4日** 支持了 [ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B) 模型，该模型基于 Wan 2.1 训练，支持根据参考视频生成相应的动作。\n\n- **2025年10月30日** 支持了 [meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video) 模型，该模型支持文生视频、图生视频、视频续写。这个模型在本项目中沿用 Wan 的框架进行推理和训练。\n\n- **2025年10月27日** 支持了 [krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video) 模型，Wan 模型生态再添一员。\n\n- **2025年9月23日** [DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster) 发布！本模型由我们与淘天体验设计团队联合研发并开源。模型基于 Qwen-Image 构建，专为电商海报场景设计，支持精确的分区布局控制。 请参考[我们的示例代码](./examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)。\n\n- **2025年9月9日** 我们的训练框架支持了多种训练模式，目前已适配 Qwen-Image，除标准 SFT 训练模式外，已支持 Direct Distill，请参考[我们的示例代码](./examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh)。这项功能是实验性的，我们将会继续完善以支持更全面的模型训练功能。\n\n- **2025年8月28日** 我们支持了Wan2.2-S2V，一个音频驱动的电影级视频生成模型。请参见[./examples/wanvideo/](./examples/wanvideo/)。\n\n- **2025年8月21日** [DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2) 发布！相比于 V1 版本，训练数据集变为 [Qwen-Image-Self-Generated-Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Qwen-Image-Self-Generated-Dataset)，因此，生成的图像更符合 Qwen-Image 本身的图像分布和风格。 
请参考[我们的示例代码](./examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)。\n\n- **2025年8月21日** 我们开源了 [DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union) 结构控制 LoRA 模型，采用 In Context 的技术路线，支持多种类别的结构控制条件，包括 canny, depth, lineart, softedge, normal, openpose。 请参考[我们的示例代码](./examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py)。\n\n- **2025年8月20日** 我们开源了 [DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix) 模型，提升了 Qwen-Image-Edit 对低分辨率图像输入的编辑效果。请参考[我们的示例代码](./examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py)\n\n- **2025年8月19日** 🔥 Qwen-Image-Edit 开源，欢迎图像编辑模型新成员！\n\n- **2025年8月18日** 我们训练并开源了 Qwen-Image 的图像重绘 ControlNet 模型 [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint)，模型结构采用了轻量化的设计，请参考[我们的示例代码](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py)。\n\n- **2025年8月15日** 我们开源了 [Qwen-Image-Self-Generated-Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Qwen-Image-Self-Generated-Dataset) 数据集。这是一个使用 Qwen-Image 模型生成的图像数据集，共包含 160,000 张`1024 x 1024`图像。它包括通用、英文文本渲染和中文文本渲染子集。我们为每张图像提供了图像描述、实体和结构控制图像的标注。开发者可以使用这个数据集来训练 Qwen-Image 模型的 ControlNet 和 EliGen 等模型，我们旨在通过开源推动技术发展！\n\n- **2025年8月13日** 我们训练并开源了 Qwen-Image 的 ControlNet 模型 [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth)，模型结构采用了轻量化的设计，请参考[我们的示例代码](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py)。\n\n- **2025年8月12日** 我们训练并开源了 Qwen-Image 的 ControlNet 模型 
[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny)，模型结构采用了轻量化的设计，请参考[我们的示例代码](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py)。\n\n- **2025年8月11日** 我们开源了 Qwen-Image 的蒸馏加速模型 [DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA)，沿用了与 [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full) 相同的训练流程，但模型结构修改为了 LoRA，因此能够更好地与其他开源生态模型兼容。\n\n- **2025年8月7日** 我们开源了 Qwen-Image 的实体控制 LoRA 模型 [DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)。Qwen-Image-EliGen 能够实现实体级可控的文生图。技术细节请参见[论文](https://arxiv.org/abs/2501.01097)。训练数据集：[EliGenTrainSet](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet)。\n\n- **2025年8月5日** 我们开源了 Qwen-Image 的蒸馏加速模型 [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)，实现了约 5 倍加速。\n\n- **2025年8月4日** 🔥 Qwen-Image 开源，欢迎图像生成模型家族新成员！\n\n- **2025年8月1日** [FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev) 开源，这是一个专注于美学摄影的文生图模型。我们第一时间提供了全方位支持，包括低显存逐层 offload、LoRA 训练、全量训练。详细信息请参考 [./examples/flux/](./examples/flux/)。\n\n- **2025年7月28日** Wan 2.2 开源，我们第一时间提供了全方位支持，包括低显存逐层 offload、FP8 量化、序列并行、LoRA 训练、全量训练。详细信息请参考 [./examples/wanvideo/](./examples/wanvideo/)。\n\n- **2025年7月11日** 我们提出 Nexus-Gen，一个将大语言模型（LLM）的语言推理能力与扩散模型的图像生成能力相结合的统一框架。该框架支持无缝的图像理解、生成和编辑任务。\n  - 论文: [Nexus-Gen: Unified Image Understanding, Generation, and Editing via Prefilled Autoregression in Shared Embedding Space](https://arxiv.org/pdf/2504.21356)\n  - Github 仓库: https://github.com/modelscope/Nexus-Gen\n  - 模型: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2), [HuggingFace](https://huggingface.co/modelscope/Nexus-GenV2)\n  - 训练数据集: [ModelScope 
Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Nexus-Gen-Training-Dataset)\n  - 在线体验: [ModelScope Nexus-Gen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/Nexus-Gen)\n\n- **2025年6月15日** ModelScope 官方评测框架 [EvalScope](https://github.com/modelscope/evalscope) 现已支持文生图生成评测。请参考[最佳实践](https://evalscope.readthedocs.io/zh-cn/latest/best_practice/t2i_eval.html)指南进行尝试。\n\n- **2025年3月25日** 我们的新开源项目 [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine) 现已开源！专注于稳定的模型部署，面向工业界，提供更好的工程支持、更高的计算性能和更稳定的功能。\n\n- **2025年3月31日** 我们支持 InfiniteYou，一种用于 FLUX 的人脸特征保留方法。更多细节请参考 [./examples/InfiniteYou/](./examples/InfiniteYou/)。\n\n- **2025年3月13日** 我们支持 HunyuanVideo-I2V，即腾讯开源的 HunyuanVideo 的图像到视频生成版本。更多细节请参考 [./examples/HunyuanVideo/](./examples/HunyuanVideo/)。\n\n- **2025年2月25日** 我们支持 Wan-Video，这是阿里巴巴开源的一系列最先进的视频合成模型。详见 [./examples/wanvideo/](./examples/wanvideo/)。\n\n- **2025年2月17日** 我们支持 [StepVideo](https://modelscope.cn/models/stepfun-ai/stepvideo-t2v/summary)！先进的视频合成模型！详见 [./examples/stepvideo](./examples/stepvideo/)。\n\n- **2024年12月31日** 我们提出 EliGen，一种用于精确实体级别控制的文本到图像生成的新框架，并辅以修复融合管道，将其能力扩展到图像修复任务。EliGen 可以无缝集成现有的社区模型，如 IP-Adapter 和 In-Context LoRA，提升其通用性。更多详情，请见 [./examples/EntityControl](./examples/EntityControl/)。\n  - 论文: [EliGen: Entity-Level Controlled Image Generation with Regional Attention](https://arxiv.org/abs/2501.01097)\n  - 模型: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen), [HuggingFace](https://huggingface.co/modelscope/EliGen)\n  - 在线体验: [ModelScope EliGen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/EliGen)\n  - 训练数据集: [EliGen Train Set](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet)\n\n- **2024年12月19日** 我们为 HunyuanVideo 实现了高级显存管理，使得在 24GB 显存下可以生成分辨率为 129x720x1280 的视频，或在仅 6GB 显存下生成分辨率为 129x512x384 的视频。更多细节请参考 [./examples/HunyuanVideo/](./examples/HunyuanVideo/)。\n\n- **2024年12月18日** 我们提出 ArtAug，一种通过合成-理解交互来改进文生图模型的方法。我们以 LoRA 格式为 FLUX.1-dev 训练了一个 ArtAug 增强模块。该模型将 
Qwen2-VL-72B 的美学理解融入 FLUX.1-dev，从而提升了生成图像的质量。\n  - 论文: https://arxiv.org/abs/2412.12888\n  - 示例: https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/ArtAug\n  - 模型: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ArtAug-lora-FLUX.1dev-v1)\n  - 演示: [ModelScope](https://modelscope.cn/aigc/imageGeneration?tab=advanced&versionId=7228&modelType=LoRA&sdVersion=FLUX_1&modelUrl=modelscope%3A%2F%2FDiffSynth-Studio%2FArtAug-lora-FLUX.1dev-v1%3Frevision%3Dv1.0), HuggingFace (即将上线)\n\n- **2024年10月25日** 我们提供了广泛的 FLUX ControlNet 支持。该项目支持许多不同的 ControlNet 模型，并且可以自由组合，即使它们的结构不同。此外，ControlNet 模型兼容高分辨率优化和分区控制技术，能够实现非常强大的可控图像生成。详见 [`./examples/ControlNet/`](./examples/ControlNet/)。\n\n- **2024年10月8日** 我们发布了基于 CogVideoX-5B 和 ExVideo 的扩展 LoRA。您可以从 [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1) 或 [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1) 下载此模型。\n\n- **2024年8月22日** 本项目现已支持 CogVideoX-5B。详见 [此处](/examples/video_synthesis/)。我们为这个文生视频模型提供了几个有趣的功能，包括：\n  - 文本到视频\n  - 视频编辑\n  - 自我超分\n  - 视频插帧\n\n- **2024年8月22日** 我们实现了一个有趣的画笔功能，支持所有文生图模型。现在，您可以在 AI 的辅助下使用画笔创作惊艳的图像了！\n  - 在我们的 [WebUI](#usage-in-webui) 中使用它。\n\n- **2024年8月21日** DiffSynth-Studio 现已支持 FLUX。\n  - 启用 CFG 和高分辨率修复以提升视觉质量。详见 [此处](/examples/image_synthesis/README.md)\n  - LoRA、ControlNet 和其他附加模型将很快推出。\n\n- **2024年6月21日** 我们提出 ExVideo，一种旨在增强视频生成模型能力的后训练微调技术。我们将 Stable Video Diffusion 进行了扩展，实现了长达 128 帧的长视频生成。\n  - [项目页面](https://ecnu-cilab.github.io/ExVideoProjectPage/)\n  - 源代码已在此仓库中发布。详见 [`examples/ExVideo`](./examples/ExVideo/)。\n  - 模型已发布于 [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1) 和 [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1)。\n  - 技术报告已发布于 [arXiv](https://arxiv.org/abs/2406.14130)。\n  - 您可以在此 [演示](https://huggingface.co/spaces/modelscope/ExVideo-SVD-128f-v1) 中试用 ExVideo！\n\n- **2024年6月13日** DiffSynth 
Studio 已迁移至 ModelScope。开发团队也从“我”转变为“我们”。当然，我仍会参与后续的开发和维护工作。\n\n- **2024年1月29日** 我们提出 Diffutoon，这是一个出色的卡通着色解决方案。\n  - [项目页面](https://ecnu-cilab.github.io/DiffutoonProjectPage/)\n  - 源代码已在此项目中发布。\n  - 技术报告（IJCAI 2024）已发布于 [arXiv](https://arxiv.org/abs/2401.16224)。\n\n- **2023年12月8日** 我们决定启动一个新项目，旨在释放扩散模型的潜力，尤其是在视频合成方面。该项目的开发工作正式开始。\n\n- **2023年11月15日** 我们提出 FastBlend，一种强大的视频去闪烁算法。\n  - sd-webui 扩展已发布于 [GitHub](https://github.com/Artiprocher/sd-webui-fastblend)。\n  - 演示视频已在 Bilibili 上展示，包含三个任务：\n    - [视频去闪烁](https://www.bilibili.com/video/BV1d94y1W7PE)\n    - [视频插帧](https://www.bilibili.com/video/BV1Lw411m71p)\n    - [图像驱动的视频渲染](https://www.bilibili.com/video/BV1RB4y1Z7LF)\n  - 技术报告已发布于 [arXiv](https://arxiv.org/abs/2311.09265)。\n  - 其他用户开发的非官方 ComfyUI 扩展已发布于 [GitHub](https://github.com/AInseven/ComfyUI-fastblend)。\n\n- **2023年10月1日** 我们发布了该项目的早期版本，名为 FastSDXL。这是构建一个扩散引擎的初步尝试。\n  - 源代码已发布于 [GitHub](https://github.com/Artiprocher/FastSDXL)。\n  - FastSDXL 包含一个可训练的 OLSS 调度器，以提高效率。\n    - OLSS 的原始仓库位于 [此处](https://github.com/alibaba/EasyNLP/tree/master/diffusion/olss_scheduler)。\n    - 技术报告（CIKM 2023）已发布于 [arXiv](https://arxiv.org/abs/2305.14677)。\n    - 演示视频已发布于 [Bilibili](https://www.bilibili.com/video/BV1w8411y7uj)。\n    - 由于 OLSS 需要额外训练，我们未在本项目中实现它。\n\n- **2023年8月29日** 我们提出 DiffSynth，一个视频合成框架。\n  - [项目页面](https://ecnu-cilab.github.io/DiffSynth.github.io/)。\n  - 源代码已发布在 [EasyNLP](https://github.com/alibaba/EasyNLP/tree/master/diffusion/DiffSynth)。\n  - 技术报告（ECML PKDD 2024）已发布于 [arXiv](https://arxiv.org/abs/2308.03463)。\n\n</details>\n\n## 安装\n\n从源码安装（推荐）：\n\n```\ngit clone https://github.com/modelscope/DiffSynth-Studio.git\ncd DiffSynth-Studio\npip install -e .\n```\n\n更多安装方式，以及非 NVIDIA GPU 的安装，请参考[安装文档](/docs/zh/Pipeline_Usage/Setup.md)。\n\n## 基础框架\n\nDiffSynth-Studio 为主流 Diffusion 模型（包括 FLUX、Wan 等）重新设计了推理和训练流水线，能够实现高效的显存管理、灵活的模型训练。\n\n<details>\n<summary>环境变量配置</summary>\n\n> 
在进行模型推理和训练前，可通过[环境变量](/docs/zh/Pipeline_Usage/Environment_Variables.md)配置模型下载源等。\n> \n> 本项目默认从魔搭社区下载模型。对于非中国区域的用户，可以通过以下配置从魔搭社区的国际站下载模型：\n> \n> ```python\n> import os\n> os.environ[\"MODELSCOPE_DOMAIN\"] = \"www.modelscope.ai\"\n> ```\n> \n> 如需从其他站点下载，请修改[环境变量 DIFFSYNTH_DOWNLOAD_SOURCE](/docs/zh/Pipeline_Usage/Environment_Variables.md#diffsynth_download_source)。\n\n</details>\n\n### 图像生成模型\n\n![Image](https://github.com/user-attachments/assets/c01258e2-f251-441a-aa1e-ebb22f02594d)\n\n#### Z-Image：[/docs/zh/Model_Details/Z-Image.md](/docs/zh/Model_Details/Z-Image.md)\n\n<details>\n\n<summary>快速开始</summary>\n\n运行以下代码可以快速加载 [Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo) 模型并进行推理。FP8 精度量化会导致明显的图像质量劣化，因此不建议在 Z-Image Turbo 模型上开启任何量化，仅建议开启 CPU Offload，最低 8G 显存即可运行。\n\n```python\nfrom diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig\nimport torch\n\nvram_config = {\n    \"offload_dtype\": torch.bfloat16,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = ZImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"transformer/*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"text_encoder/*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"tokenizer/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\nprompt = \"Young Chinese woman in red Hanfu, intricate 
embroidery. Impeccable makeup, red floral forehead pattern. Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights.\"\nimage = pipe(prompt=prompt, seed=42, rand_device=\"cuda\")\nimage.save(\"image.jpg\")\n```\n\n</details>\n\n<details>\n\n<summary>示例代码</summary>\n\nZ-Image 的示例代码位于：[/examples/z_image/](/examples/z_image/)\n\n|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|\n|-|-|-|-|-|-|-|\n|[Tongyi-MAI/Z-Image](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image)|[code](/examples/z_image/model_inference/Z-Image.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image.py)|[code](/examples/z_image/model_training/full/Z-Image.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image.py)|[code](/examples/z_image/model_training/lora/Z-Image.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image.py)|\n|[DiffSynth-Studio/Z-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-i2L)|[code](/examples/z_image/model_inference/Z-Image-i2L.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-i2L.py)|-|-|-|-|\n|[Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo)|[code](/examples/z_image/model_inference/Z-Image-Turbo.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo.py)|\n|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/e
xamples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|\n|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|\n|[PAI/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|\n\n</details>\n\n#### FLUX.2: 
[/docs/zh/Model_Details/FLUX2.md](/docs/zh/Model_Details/FLUX2.md)\n\n<details>\n\n<summary>快速开始</summary>\n\n运行以下代码可以快速加载 [black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev) 模型并进行推理。显存管理已启动，框架会自动根据剩余显存控制模型参数的加载，最低 10G 显存即可运行。\n\n```python\nfrom diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig\nimport torch\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.float8_e4m3fn,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e4m3fn,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = Flux2ImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-dev\", origin_file_pattern=\"text_encoder/*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-dev\", origin_file_pattern=\"transformer/*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-dev\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"black-forest-labs/FLUX.2-dev\", origin_file_pattern=\"tokenizer/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\nprompt = \"High resolution. A dreamy underwater portrait of a serene young woman in a flowing blue dress. Her hair floats softly around her face, strands delicately suspended in the water. Clear, shimmering light filters through, casting gentle highlights, while tiny bubbles rise around her. 
Her expression is calm, her features finely detailed—creating a tranquil, ethereal scene.\"\nimage = pipe(prompt, seed=42, rand_device=\"cuda\", num_inference_steps=50)\nimage.save(\"image.jpg\")\n```\n\n</details>\n\n<details>\n\n<summary>示例代码</summary>\n\nFLUX.2 的示例代码位于：[/examples/flux2/](/examples/flux2/)\n\n|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|\n|-|-|-|-|-|-|-|\n|[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](/examples/flux2/model_inference/FLUX.2-dev.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|-|-|[code](/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)|\n|[black-forest-labs/FLUX.2-klein-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-4B.py)|\n|[black-forest-labs/FLUX.2-klein-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-9B.py)|\n|[black-forest-labs/FLUX.2-klein-base-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-4B.py)|[code](/examples/
flux2/model_training/full/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-4B.py)|\n|[black-forest-labs/FLUX.2-klein-base-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-9B.py)|\n\n</details>\n\n#### Anima: [/docs/zh/Model_Details/Anima.md](/docs/zh/Model_Details/Anima.md)\n\n<details>\n\n<summary>快速开始</summary>\n\n运行以下代码可以快速加载 [circlestone-labs/Anima](https://www.modelscope.cn/models/circlestone-labs/Anima) 模型并进行推理。显存管理已启动，框架会自动根据剩余显存控制模型参数的加载，最低 8G 显存即可运行。\n\n```python\nfrom diffsynth.pipelines.anima_image import AnimaImagePipeline, ModelConfig\nimport torch\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": \"disk\",\n    \"onload_device\": \"disk\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = AnimaImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"circlestone-labs/Anima\", origin_file_pattern=\"split_files/diffusion_models/anima-preview.safetensors\", **vram_config),\n        ModelConfig(model_id=\"circlestone-labs/Anima\", origin_file_pattern=\"split_files/text_encoders/qwen_3_06b_base.safetensors\", **vram_config),\n        
ModelConfig(model_id=\"circlestone-labs/Anima\", origin_file_pattern=\"split_files/vae/qwen_image_vae.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen3-0.6B\", origin_file_pattern=\"./\"),\n    tokenizer_t5xxl_config=ModelConfig(model_id=\"stabilityai/stable-diffusion-3.5-large\", origin_file_pattern=\"tokenizer_3/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\nprompt = \"Masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait.\"\nnegative_prompt = \"worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,\"\nimage = pipe(prompt, negative_prompt=negative_prompt, seed=0, num_inference_steps=50)\nimage.save(\"image.jpg\")\n```\n\n</details>\n\n<details>\n\n<summary>示例代码</summary>\n\nAnima 的示例代码位于：[/examples/anima/](/examples/anima/)\n\n|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|\n|-|-|-|-|-|-|-|\n|[circlestone-labs/Anima](https://www.modelscope.cn/models/circlestone-labs/Anima)|[code](/examples/anima/model_inference/anima-preview.py)|[code](/examples/anima/model_inference_low_vram/anima-preview.py)|[code](/examples/anima/model_training/full/anima-preview.sh)|[code](/examples/anima/model_training/validate_full/anima-preview.py)|[code](/examples/anima/model_training/lora/anima-preview.sh)|[code](/examples/anima/model_training/validate_lora/anima-preview.py)|\n\n</details>\n\n#### Qwen-Image: [/docs/zh/Model_Details/Qwen-Image.md](/docs/zh/Model_Details/Qwen-Image.md)\n\n<details>\n\n<summary>快速开始</summary>\n\n运行以下代码可以快速加载 [Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image) 模型并进行推理。显存管理已启动，框架会自动根据剩余显存控制模型参数的加载，最低 8G 显存即可运行。\n\n```python\nfrom diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nimport torch\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": 
torch.float8_e4m3fn,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e4m3fn,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\nprompt = \"精致肖像，水下少女，蓝裙飘逸，发丝轻扬，光影透澈，气泡环绕，面容恬静，细节精致，梦幻唯美。\"\nimage = pipe(prompt, seed=0, num_inference_steps=40)\nimage.save(\"image.jpg\")\n```\n\n</details>\n\n<details>\n\n<summary>模型血缘</summary>\n\n```mermaid\ngraph LR;\n    Qwen/Qwen-Image-->Qwen/Qwen-Image-Edit;\n    Qwen/Qwen-Image-Edit-->Qwen/Qwen-Image-Edit-2509;\n    Qwen/Qwen-Image-->EliGen-Series;\n    EliGen-Series-->DiffSynth-Studio/Qwen-Image-EliGen;\n    DiffSynth-Studio/Qwen-Image-EliGen-->DiffSynth-Studio/Qwen-Image-EliGen-V2;\n    EliGen-Series-->DiffSynth-Studio/Qwen-Image-EliGen-Poster;\n    Qwen/Qwen-Image-->Distill-Series;\n    Distill-Series-->DiffSynth-Studio/Qwen-Image-Distill-Full;\n    Distill-Series-->DiffSynth-Studio/Qwen-Image-Distill-LoRA;\n    Qwen/Qwen-Image-->ControlNet-Series;\n    ControlNet-Series-->Blockwise-ControlNet-Series;\n    Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny;\n    Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth;\n    
Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint;\n    ControlNet-Series-->DiffSynth-Studio/Qwen-Image-In-Context-Control-Union;\n    Qwen/Qwen-Image-->DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix;\n```\n\n</details>\n\n<details>\n\n<summary>示例代码</summary>\n\nQwen-Image 的示例代码位于：[/examples/qwen_image/](/examples/qwen_image/)\n\n|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|\n|-|-|-|-|-|-|-|\n|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](/examples/qwen_image/model_inference/Qwen-Image.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image.py)|\n|[Qwen/Qwen-Image-2512](https://www.modelscope.cn/models/Qwen/Qwen-Image-2512)|[code](/examples/qwen_image/model_inference/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-2512.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-2512.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-2512.py)|\n|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)|\n|[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-
2509)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)|\n|[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)|\n|[FireRedTeam/FireRed-Image-Edit-1.0](https://www.modelscope.cn/models/FireRedTeam/FireRed-Image-Edit-1.0)|[code](/examples/qwen_image/model_inference/FireRed-Image-Edit-1.0.py)|[code](/examples/qwen_image/model_inference_low_vram/FireRed-Image-Edit-1.0.py)|[code](/examples/qwen_image/model_training/full/FireRed-Image-Edit-1.0.sh)|[code](/examples/qwen_image/model_training/validate_full/FireRed-Image-Edit-1.0.py)|[code](/examples/qwen_image/model_training/lora/FireRed-Image-Edit-1.0.sh)|[code](/examples/qwen_image/model_training/validate_lora/FireRed-Image-Edit-1.0.py)|\n|[FireRedTeam/FireRed-Image-Edit-1.1](https://www.modelscope.cn/models/FireRedTeam/FireRed-Image-Edit-1.1)|[code](/examples/qwen_image/model_inference/FireRed-Image-Edit-1.1.py)|[code](/examples/qwen_image/model_inference_low_vram/FireRed-Image-Edit-1.1.py)|[code](/examples/qwen_image/model_training/full/FireRed-Image-Edit-1.1.sh)|[code](/examples/qwen_image/model_training/validate_full/FireRed-Image-Edit-
1.1.py)|[code](/examples/qwen_image/model_training/lora/FireRed-Image-Edit-1.1.sh)|[code](/examples/qwen_image/model_training/validate_lora/FireRed-Image-Edit-1.1.py)|\n|[lightx2v/Qwen-Image-Edit-2511-Lightning](https://modelscope.cn/models/lightx2v/Qwen-Image-Edit-2511-Lightning)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-Lightning.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511-Lightning.py)|-|-|-|-|\n|[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)|\n|[DiffSynth-Studio/Qwen-Image-Layered-Control](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control.py)|\n|[DiffSynth-Studio/Qwen-Image-Layered-Control-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control-V2)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered-Control-V2.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control-V2.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control-V2.sh)|[code](
/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control-V2.py)|\n|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|\n|[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|\n|[DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py)|\n|[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](/examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-Full.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py)|\n|[DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www
.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA)|[code](/examples/qwen_image/model_inference/Qwen-Image-Distill-LoRA.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-LoRA.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-LoRA.py)|\n|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Canny.py)|\n|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Depth.py)|\n|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockw
ise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|\n|[DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union)|[code](/examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py)|\n|[DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py)|-|-|-|-|\n|[DiffSynth-Studio/Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L)|[code](/examples/qwen_image/model_inference/Qwen-Image-i2L.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-i2L.py)|-|-|-|-|\n\n</details>\n\n#### FLUX.1: [/docs/zh/Model_Details/FLUX.md](/docs/zh/Model_Details/FLUX.md)\n\n<details>\n\n<summary>快速开始</summary>\n\n运行以下代码可以快速加载 [black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev) 模型并进行推理。显存管理已启动，框架会自动根据剩余显存控制模型参数的加载，最低 8G 显存即可运行。\n\n```python\nimport torch\nfrom diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig\n\nvram_config = {\n    
\"offload_dtype\": torch.float8_e4m3fn,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.float8_e4m3fn,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e4m3fn,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = FluxImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"flux1-dev.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder/model.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder_2/*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"ae.safetensors\", **vram_config),\n    ],\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 1,\n)\nprompt = \"CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. 
The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her.\"\nimage = pipe(prompt=prompt, seed=0)\nimage.save(\"image.jpg\")\n```\n\n</details>\n\n<details>\n\n<summary>模型血缘</summary>\n\n```mermaid\ngraph LR;\n    FLUX.1-Series-->black-forest-labs/FLUX.1-dev;\n    FLUX.1-Series-->black-forest-labs/FLUX.1-Krea-dev;\n    FLUX.1-Series-->black-forest-labs/FLUX.1-Kontext-dev;\n    black-forest-labs/FLUX.1-dev-->FLUX.1-dev-ControlNet-Series;\n    FLUX.1-dev-ControlNet-Series-->alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta;\n    FLUX.1-dev-ControlNet-Series-->InstantX/FLUX.1-dev-Controlnet-Union-alpha;\n    FLUX.1-dev-ControlNet-Series-->jasperai/Flux.1-dev-Controlnet-Upscaler;\n    black-forest-labs/FLUX.1-dev-->InstantX/FLUX.1-dev-IP-Adapter;\n    black-forest-labs/FLUX.1-dev-->ByteDance/InfiniteYou;\n    black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/Eligen;\n    black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev;\n    black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev;\n    black-forest-labs/FLUX.1-dev-->ostris/Flex.2-preview;\n    black-forest-labs/FLUX.1-dev-->stepfun-ai/Step1X-Edit;\n    Qwen/Qwen2.5-VL-7B-Instruct-->stepfun-ai/Step1X-Edit;\n    black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/Nexus-GenV2;\n    Qwen/Qwen2.5-VL-7B-Instruct-->DiffSynth-Studio/Nexus-GenV2;\n```\n\n</details>\n\n<details>\n\n<summary>示例代码</summary>\n\nFLUX.1 的示例代码位于：[/examples/flux/](/examples/flux/)\n\n|模型 ID|额外参数|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 
训练后验证|\n|-|-|-|-|-|-|-|-|\n|[black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev)||[code](/examples/flux/model_inference/FLUX.1-dev.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev.py)|[code](/examples/flux/model_training/full/FLUX.1-dev.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev.py)|\n|[black-forest-labs/FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev)||[code](/examples/flux/model_inference/FLUX.1-Krea-dev.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-Krea-dev.py)|[code](/examples/flux/model_training/full/FLUX.1-Krea-dev.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-Krea-dev.py)|[code](/examples/flux/model_training/lora/FLUX.1-Krea-dev.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-Krea-dev.py)|\n|[black-forest-labs/FLUX.1-Kontext-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev)|`kontext_images`|[code](/examples/flux/model_inference/FLUX.1-Kontext-dev.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-Kontext-dev.py)|[code](/examples/flux/model_training/full/FLUX.1-Kontext-dev.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-Kontext-dev.py)|[code](/examples/flux/model_training/lora/FLUX.1-Kontext-dev.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-Kontext-dev.py)|\n|[alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta](https://www.modelscope.cn/models/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta)|`controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](/examples/flux/model_training/validate_full/FLUX.
1-dev-Controlnet-Inpainting-Beta.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|\n|[InstantX/FLUX.1-dev-Controlnet-Union-alpha](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-Controlnet-Union-alpha)|`controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Union-alpha.py)|\n|[jasperai/Flux.1-dev-Controlnet-Upscaler](https://www.modelscope.cn/models/jasperai/Flux.1-dev-Controlnet-Upscaler)|`controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Upscaler.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Upscaler.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Upscaler.py)|\n|[InstantX/FLUX.1-dev-IP-Adapter](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-IP-Adapter)|`ipadapter_images`, 
`ipadapter_scale`|[code](/examples/flux/model_inference/FLUX.1-dev-IP-Adapter.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-IP-Adapter.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-IP-Adapter.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-IP-Adapter.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-IP-Adapter.py)|\n|[ByteDance/InfiniteYou](https://www.modelscope.cn/models/ByteDance/InfiniteYou)|`infinityou_id_image`, `infinityou_guidance`, `controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-InfiniteYou.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-InfiniteYou.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-InfiniteYou.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-InfiniteYou.py)|\n|[DiffSynth-Studio/Eligen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)|`eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint`|[code](/examples/flux/model_inference/FLUX.1-dev-EliGen.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py)|-|-|[code](/examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-EliGen.py)|\n|[DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev](https://www.modelscope.cn/models/DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev)|`lora_encoder_inputs`, 
`lora_encoder_scale`|[code](/examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-LoRA-Encoder.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py)|-|-|\n|[DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev)||[code](/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py)|-|-|-|-|-|\n|[stepfun-ai/Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit)|`step1x_reference_image`|[code](/examples/flux/model_inference/Step1X-Edit.py)|[code](/examples/flux/model_inference_low_vram/Step1X-Edit.py)|[code](/examples/flux/model_training/full/Step1X-Edit.sh)|[code](/examples/flux/model_training/validate_full/Step1X-Edit.py)|[code](/examples/flux/model_training/lora/Step1X-Edit.sh)|[code](/examples/flux/model_training/validate_lora/Step1X-Edit.py)|\n|[ostris/Flex.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview)|`flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, 
`flex_control_stop`|[code](/examples/flux/model_inference/FLEX.2-preview.py)|[code](/examples/flux/model_inference_low_vram/FLEX.2-preview.py)|[code](/examples/flux/model_training/full/FLEX.2-preview.sh)|[code](/examples/flux/model_training/validate_full/FLEX.2-preview.py)|[code](/examples/flux/model_training/lora/FLEX.2-preview.sh)|[code](/examples/flux/model_training/validate_lora/FLEX.2-preview.py)|\n|[DiffSynth-Studio/Nexus-GenV2](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2)|`nexus_gen_reference_image`|[code](/examples/flux/model_inference/Nexus-Gen-Editing.py)|[code](/examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py)|[code](/examples/flux/model_training/full/Nexus-Gen.sh)|[code](/examples/flux/model_training/validate_full/Nexus-Gen.py)|[code](/examples/flux/model_training/lora/Nexus-Gen.sh)|[code](/examples/flux/model_training/validate_lora/Nexus-Gen.py)|\n\n</details>\n\n### 视频生成模型\n\nhttps://github.com/user-attachments/assets/1d66ae74-3b02-40a9-acc3-ea95fc039314\n\n#### LTX-2: [/docs/zh/Model_Details/LTX-2.md](/docs/zh/Model_Details/LTX-2.md)\n\n<details>\n\n<summary>快速开始</summary>\n\n运行以下代码可以快速加载 [Lightricks/LTX-2](https://www.modelscope.cn/models/Lightricks/LTX-2) 模型并进行推理。显存管理已启用，框架会自动根据剩余显存控制模型参数的加载，最低 8GB 显存即可运行。\n\n```python\nimport torch\nfrom diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig\nfrom diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2\n\nvram_config = {\n    \"offload_dtype\": torch.float8_e5m2,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.float8_e5m2,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e5m2,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\n\"\"\"\nOfficial model repo: https://www.modelscope.cn/models/Lightricks/LTX-2\nRepackaged model repo: https://www.modelscope.cn/models/DiffSynth-Studio/LTX-2-Repackage\nFor base models of LTX-2, official 
checkpoint (with model config ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-19b-dev.safetensors\"))\nand repackaged checkpoints (with model config ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"*.safetensors\")) are both supported.\nWe have repackaged the official checkpoints in DiffSynth-Studio/LTX-2-Repackage repo to support separate loading of different submodules,\nand avoid redundant memory usage when users only want to use part of the model.\n\"\"\"\n# use the repackaged modelconfig from \"DiffSynth-Studio/LTX-2-Repackage\" to avoid redundant model loading\npipe = LTX2AudioVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\", origin_file_pattern=\"model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"transformer.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"text_encoder_post_modules.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"video_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"audio_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"audio_vocoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"video_vae_encoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-spatial-upscaler-x2-1.0.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\"),\n    
stage2_lora_config=ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-19b-distilled-lora-384.safetensors\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\n\n# use the following modelconfig if you want to initialize model from official checkpoints from \"Lightricks/LTX-2\"\n# pipe = LTX2AudioVideoPipeline.from_pretrained(\n#     torch_dtype=torch.bfloat16,\n#     device=\"cuda\",\n#     model_configs=[\n#         ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\", origin_file_pattern=\"model-*.safetensors\", **vram_config),\n#         ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-19b-dev.safetensors\", **vram_config),\n#         ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-spatial-upscaler-x2-1.0.safetensors\", **vram_config),\n#     ],\n#     tokenizer_config=ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\"),\n#     stage2_lora_config=ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-19b-distilled-lora-384.safetensors\"),\n#     vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n# )\n\nprompt = \"A girl is very happy, she is speaking: \\\"I enjoy working with Diffsynth-Studio, it's a perfect framework.\\\"\"\nnegative_prompt = (\n    \"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, \"\n    \"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, \"\n    \"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, \"\n    \"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of \"\n    \"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent \"\n    \"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny \"\n    
\"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, \"\n    \"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, \"\n    \"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward \"\n    \"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, \"\n    \"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts.\"\n)\nheight, width, num_frames = 512 * 2, 768 * 2, 121\nvideo, audio = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    seed=43,\n    height=height,\n    width=width,\n    num_frames=num_frames,\n    tiled=True,\n    use_two_stage_pipeline=True,\n)\nwrite_video_audio_ltx2(\n    video=video,\n    audio=audio,\n    output_path='ltx2_twostage.mp4',\n    fps=24,\n    audio_sample_rate=24000,\n)\n```\n\n</details>\n\n<details>\n\n<summary>示例代码</summary>\n\nLTX-2 的示例代码位于：[/examples/ltx2/](/examples/ltx2/)\n\n|模型 ID|额外参数|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|\n|-|-|-|-|-|-|-|-|\n|[Lightricks/LTX-2.3: OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2.3-I2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-I2AV-OneStage.py)|[code](/examples/ltx2/model_training/full/LTX-2.3-I2AV-splited.sh)|[code](/examples/ltx2/model_training/validate_full/LTX-2.3-I2AV.py)|[code](/examples/ltx2/model_training/lora/LTX-2.3-I2AV-splited.sh)|[code](/examples/ltx2/model_training/validate_lora/LTX-2.3-I2AV.py)|\n|[Lightricks/LTX-2.3: TwoStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2.3-I2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-I2AV-TwoStage.py)|-|-|-|-|\n|[Lightricks/LTX-2.3: 
DistilledPipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2.3-I2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-I2AV-DistilledPipeline.py)|-|-|-|-|\n|[Lightricks/LTX-2.3: OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)||[code](/examples/ltx2/model_inference/LTX-2.3-T2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-OneStage.py)|[code](/examples/ltx2/model_training/full/LTX-2.3-T2AV-splited.sh)|[code](/examples/ltx2/model_training/validate_full/LTX-2.3-T2AV.py)|[code](/examples/ltx2/model_training/lora/LTX-2.3-T2AV-splited.sh)|[code](/examples/ltx2/model_training/validate_lora/LTX-2.3-T2AV.py)|\n|[Lightricks/LTX-2.3: TwoStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)||[code](/examples/ltx2/model_inference/LTX-2.3-T2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-TwoStage.py)|-|-|-|-|\n|[Lightricks/LTX-2.3: DistilledPipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)||[code](/examples/ltx2/model_inference/LTX-2.3-T2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-DistilledPipeline.py)|-|-|-|-|\n|[Lightricks/LTX-2.3: A2V](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`retake_audio`,`audio_sample_rate`,`retake_audio_regions`|[code](/examples/ltx2/model_inference/LTX-2.3-A2V-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-A2V-TwoStage.py)|-|-|-|-|\n|[Lightricks/LTX-2.3: 
Retake](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`retake_video`,`retake_video_regions`,`retake_audio`,`audio_sample_rate`,`retake_audio_regions`|[code](/examples/ltx2/model_inference/LTX-2.3-T2AV-TwoStage-Retake.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-TwoStage-Retake.py)|-|-|-|-|\n|[Lightricks/LTX-2.3-22b-IC-LoRA-Union-Control](https://www.modelscope.cn/models/Lightricks/LTX-2.3-22b-IC-LoRA-Union-Control)|`in_context_videos`,`in_context_downsample_factor`|[code](/examples/ltx2/model_inference/LTX-2.3-T2AV-IC-LoRA-Union-Control.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-IC-LoRA-Union-Control.py)|-|-|[code](/examples/ltx2/model_training/lora/LTX-2.3-T2AV-IC-LoRA-splited.sh)|[code](/examples/ltx2/model_training/validate_lora/LTX-2.3-T2AV-IC-LoRA.py)|\n|[Lightricks/LTX-2.3-22b-IC-LoRA-Motion-Track-Control](https://www.modelscope.cn/models/Lightricks/LTX-2.3-22b-IC-LoRA-Motion-Track-Control)|`in_context_videos`,`in_context_downsample_factor`|[code](/examples/ltx2/model_inference/LTX-2.3-T2AV-IC-LoRA-Motion-Track-Control.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-IC-LoRA-Motion-Track-Control.py)|-|-|[code](/examples/ltx2/model_training/lora/LTX-2.3-T2AV-IC-LoRA-splited.sh)|[code](/examples/ltx2/model_training/validate_lora/LTX-2.3-T2AV-IC-LoRA.py)|\n|[Lightricks/LTX-2: 
OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-OneStage.py)|[code](/examples/ltx2/model_training/full/LTX-2-T2AV-splited.sh)|[code](/examples/ltx2/model_training/validate_full/LTX-2-T2AV.py)|[code](/examples/ltx2/model_training/lora/LTX-2-T2AV-splited.sh)|[code](/examples/ltx2/model_training/validate_lora/LTX-2-T2AV.py)|\n|[Lightricks/LTX-2-19b-IC-LoRA-Union-Control](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-IC-LoRA-Union-Control)|`in_context_videos`,`in_context_downsample_factor`|[code](/examples/ltx2/model_inference/LTX-2-T2AV-IC-LoRA-Union-Control.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-IC-LoRA-Union-Control.py)|-|-|[code](/examples/ltx2/model_training/lora/LTX-2-T2AV-IC-LoRA-splited.sh)|[code](/examples/ltx2/model_training/validate_lora/LTX-2-T2AV-IC-LoRA.py)|\n|[Lightricks/LTX-2-19b-IC-LoRA-Detailer](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-IC-LoRA-Detailer)|`in_context_videos`,`in_context_downsample_factor`|[code](/examples/ltx2/model_inference/LTX-2-T2AV-IC-LoRA-Detailer.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-IC-LoRA-Detailer.py)|-|-|[code](/examples/ltx2/model_training/lora/LTX-2-T2AV-IC-LoRA-splited.sh)|[code](/examples/ltx2/model_training/validate_lora/LTX-2-T2AV-IC-LoRA.py)|\n|[Lightricks/LTX-2: TwoStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-TwoStage.py)|-|-|-|-|\n|[Lightricks/LTX-2: DistilledPipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-DistilledPipeline.py)|-|-|-|-|\n|[Lightricks/LTX-2: 
OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-OneStage.py)|-|-|-|-|\n|[Lightricks/LTX-2: TwoStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-TwoStage.py)|-|-|-|-|\n|[Lightricks/LTX-2: DistilledPipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-DistilledPipeline.py)|-|-|-|-|\n|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-In](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-In)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-In.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-In.py)|-|-|-|-|\n|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Out](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Out)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Out.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Out.py)|-|-|-|-|\n|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Left](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Left)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Left.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Left.py)|-|-|-|-|\n|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Right](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Right)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Right.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dol
ly-Right.py)|-|-|-|-|\n|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Up](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Up)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Jib-Up.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Jib-Up.py)|-|-|-|-|\n|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Down](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Down)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Jib-Down.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Jib-Down.py)|-|-|-|-|\n|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Static](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Static)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Static.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Static.py)|-|-|-|-|\n\n</details>\n\n#### Wan: [/docs/zh/Model_Details/Wan.md](/docs/zh/Model_Details/Wan.md)\n\n<details>\n\n<summary>快速开始</summary>\n\n运行以下代码可以快速加载 [Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B) 模型并进行推理。显存管理已启用，框架会自动根据剩余显存控制模型参数的加载，最低 8GB 显存即可运行。\n\n```python\nimport torch\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\", **vram_config),\n        
ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\", **vram_config),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"Wan2.1_VAE.pth\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 2,\n)\n\nvideo = pipe(\n    prompt=\"纪实摄影风格画面，一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄，两只耳朵立起，神情专注而欢快。阳光洒在它身上，使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地，偶尔点缀着几朵野花，远处隐约可见蓝天和几片白云。透视感鲜明，捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    seed=0, tiled=True,\n)\nsave_video(video, \"video.mp4\", fps=15, quality=5)\n```\n\n</details>\n\n<details>\n\n<summary>模型血缘</summary>\n\n```mermaid\ngraph LR;\n    Wan-Series-->Wan2.1-Series;\n    Wan-Series-->Wan2.2-Series;\n    Wan2.1-Series-->Wan-AI/Wan2.1-T2V-1.3B;\n    Wan2.1-Series-->Wan-AI/Wan2.1-T2V-14B;\n    Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-I2V-14B-480P;\n    Wan-AI/Wan2.1-I2V-14B-480P-->Wan-AI/Wan2.1-I2V-14B-720P;\n    Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-FLF2V-14B-720P;\n    Wan-AI/Wan2.1-T2V-1.3B-->iic/VACE-Wan2.1-1.3B-Preview;\n    iic/VACE-Wan2.1-1.3B-Preview-->Wan-AI/Wan2.1-VACE-1.3B;\n    Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-VACE-14B;\n    Wan-AI/Wan2.1-T2V-1.3B-->Wan2.1-Fun-1.3B-Series;\n    Wan2.1-Fun-1.3B-Series-->PAI/Wan2.1-Fun-1.3B-InP;\n    Wan2.1-Fun-1.3B-Series-->PAI/Wan2.1-Fun-1.3B-Control;\n    Wan-AI/Wan2.1-T2V-14B-->Wan2.1-Fun-14B-Series;\n    Wan2.1-Fun-14B-Series-->PAI/Wan2.1-Fun-14B-InP;\n    Wan2.1-Fun-14B-Series-->PAI/Wan2.1-Fun-14B-Control;\n    Wan-AI/Wan2.1-T2V-1.3B-->Wan2.1-Fun-V1.1-1.3B-Series;\n    Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-Control;\n    Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-InP;\n    
Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera;\n    Wan-AI/Wan2.1-T2V-14B-->Wan2.1-Fun-V1.1-14B-Series;\n    Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-Control;\n    Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-InP;\n    Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-Control-Camera;\n    Wan-AI/Wan2.1-T2V-1.3B-->DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1;\n    Wan-AI/Wan2.1-T2V-14B-->krea/krea-realtime-video;\n    Wan-AI/Wan2.1-T2V-14B-->meituan-longcat/LongCat-Video;\n    Wan-AI/Wan2.1-I2V-14B-720P-->ByteDance/Video-As-Prompt-Wan2.1-14B;\n    Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.2-Animate-14B;\n    Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.2-S2V-14B;\n    Wan2.2-Series-->Wan-AI/Wan2.2-T2V-A14B;\n    Wan2.2-Series-->Wan-AI/Wan2.2-I2V-A14B;\n    Wan2.2-Series-->Wan-AI/Wan2.2-TI2V-5B;\n    Wan-AI/Wan2.2-T2V-A14B-->Wan2.2-Fun-Series;\n    Wan2.2-Fun-Series-->PAI/Wan2.2-VACE-Fun-A14B;\n    Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-InP;\n    Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-Control;\n    Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-Control-Camera;\n```\n\n</details>\n\n<details>\n\n<summary>示例代码</summary>\n\nWan 的示例代码位于：[/examples/wanvideo/](/examples/wanvideo/)\n\n|模型 ID|额外参数|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 
训练后验证|\n|-|-|-|-|-|-|-|-|\n|[Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-T2V-1.3B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-1.3B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-1.3B.py)|\n|[Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-T2V-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-T2V-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py)|\n|[Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_infe
rence_low_vram/Wan2.1-I2V-14B-480P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py)|\n|[Wan-AI/Wan2.1-I2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P)|`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-I2V-14B-720P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-I2V-14B-720P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-720P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-720P.py)|\n|[Wan-AI/Wan2.1-FLF2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P)|`input_image`, 
`end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-FLF2V-14B-720P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-FLF2V-14B-720P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-FLF2V-14B-720P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-FLF2V-14B-720P.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py)|\n|[iic/VACE-Wan2.1-1.3B-Preview](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview)|`vace_control_video`, `vace_reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B-Preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-VACE-1.3B-Preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B-Preview.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B-Preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B-Preview.py)|\n|[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|`vace_control_video`, 
`vace_reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-VACE-1.3B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B.py)|\n|[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, `vace_reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-VACE-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-VACE-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py)|\n|[PAI/Wan2.1-Fun-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP)|`input_image`, 
`end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-1.3B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py)|\n|[PAI/Wan2.1-Fun-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control)|`control_video`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-1.3B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py)|\n|[PAI/Wan2.1-Fun-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP)|`input_image`, 
`end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-InP.py)|\n|[PAI/Wan2.1-Fun-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control)|`control_video`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-Control.py)|\n|[PAI/Wan2.1-Fun-V1.1-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control)|`control_video`, 
`reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py)|\n|[PAI/Wan2.1-Fun-V1.1-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control)|`control_video`, `reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py)|\n|[PAI/Wan2.1-Fun-V1.1-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-InP)|`input_image`, 
`end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-InP.py)|\n|[PAI/Wan2.1-Fun-V1.1-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-InP)|`input_image`, `end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-InP.py)|\n|[PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera)|`control_camera_video`, 
`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|\n|[PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera)|`control_camera_video`, 
`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|\n|[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-1.3b-speedcontrol-v1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)|\n|[krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/
examples/wanvideo/model_inference/krea-realtime-video.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/krea-realtime-video.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/krea-realtime-video.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/krea-realtime-video.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/krea-realtime-video.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/krea-realtime-video.py)|\n|[meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video)|`longcat_video`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/LongCat-Video.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/LongCat-Video.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/LongCat-Video.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/LongCat-Video.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/LongCat-Video.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/LongCat-Video.py)|\n|[ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B)|`vap_video`, 
`vap_prompt`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Video-As-Prompt-Wan2.1-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Video-As-Prompt-Wan2.1-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Video-As-Prompt-Wan2.1-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Video-As-Prompt-Wan2.1-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Video-As-Prompt-Wan2.1-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Video-As-Prompt-Wan2.1-14B.py)|\n|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-T2V-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py)|\n|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-I2V-A14B.p
y)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py)|\n|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-TI2V-5B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py)|\n|[Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B)|`input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, 
`animate_mask_video`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Animate-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-Animate-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py)|\n|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-S2V-14B_multi_clips.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-S2V-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-S2V-14B.py)|\n|[PAI/Wan2.2-VACE-Fun-A14B](https://www.modelscope.cn/models/PAI/Wan2.2-VACE-Fun-A14B)|`vace_control_video`, 
`vace_reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-VACE-Fun-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-VACE-Fun-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-VACE-Fun-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-VACE-Fun-A14B.py)|\n|[PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP)|`input_image`, `end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-Fun-A14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py)|\n|[PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control)|`control_video`, 
`reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-Fun-A14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py)|\n|[PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera)|`control_camera_video`, `input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-Fun-A14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py)|\n|[openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p)|`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_inference/MOVA-360p-I2AV.py)
|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_inference_low_vram/MOVA-360p-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/full/MOVA-360P-I2AV.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_full/MOVA-360p-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/lora/MOVA-360P-I2AV.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_lora/MOVA-360p-I2AV.py)|\n|[openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p)|`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_inference/MOVA-720p-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_inference_low_vram/MOVA-720p-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/full/MOVA-720P-I2AV.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_full/MOVA-720p-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/lora/MOVA-720P-I2AV.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_lora/MOVA-720p-I2AV.py)|\n|[Wan-AI/WanToDance-14B (global model)](https://modelscope.cn/models/Wan-AI/WanToDance-14B)|`wantodance_music_path`, `wantodance_reference_image`, `wantodance_fps`, `wantodance_keyframes`, 
`wantodance_keyframes_mask`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/WanToDance-14B-global.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/WanToDance-14B-global.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/WanToDance-14B-global.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/WanToDance-14B-global.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/WanToDance-14B-global.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/WanToDance-14B-global.py)|\n|[Wan-AI/WanToDance-14B (local model)](https://modelscope.cn/models/Wan-AI/WanToDance-14B)|`wantodance_music_path`, `wantodance_reference_image`, `wantodance_fps`, `wantodance_keyframes`, `wantodance_keyframes_mask`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/WanToDance-14B-local.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/WanToDance-14B-local.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/WanToDance-14B-local.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/WanToDance-14B-local.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/WanToDance-14B-local.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/WanToDance-14B-local.py)|\n\n</details>\n\n## 创新成果\n\nDiffSynth-Studio 不仅仅是一个工程化的模型框架，更是创新成果的孵化器。\n\n<details>\n\n<summary>Spectral Evolution Search: 用于奖励对齐图像生成的高效推理阶段缩放</summary>\n\n- 论文：[Spectral Evolution 
Search: Efficient Inference-Time Scaling for Reward-Aligned Image Generation\n](https://arxiv.org/abs/2602.03208)\n- 代码样例：[/docs/en/Research_Tutorial/inference_time_scaling.md](/docs/en/Research_Tutorial/inference_time_scaling.md)\n\n|FLUX.1-dev|FLUX.1-dev + SES|Qwen-Image|Qwen-Image + SES|\n|-|-|-|-|\n|![Image](https://github.com/user-attachments/assets/5be15dc6-2805-4822-b04c-2573fc0f45f0)|![Image](https://github.com/user-attachments/assets/e71b8c20-1629-41d9-b0ff-185805c1da4e)|![Image](https://github.com/user-attachments/assets/7a73c968-133a-4545-9aa2-205533861cd4)|![Image](https://github.com/user-attachments/assets/c8390b22-14fe-48a0-a6e6-d6556d31235e)|\n\n</details>\n\n\n<details>\n\n<summary>VIRAL：基于DiT模型的类比视觉上下文推理</summary>\n\n- 论文：[VIRAL: Visual In-Context Reasoning via Analogy in Diffusion Transformers\n](https://arxiv.org/abs/2602.03210)\n- 代码样例：[/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-ICEdit.py](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-ICEdit.py)\n- 模型：[ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-2511-ICEdit-LoRA)\n\n|Example 1|Example 2|Query|Output|\n|-|-|-|-|\n|![Image](https://github.com/user-attachments/assets/380d2670-47bf-41cd-b5c9-37110cc4a943)|![Image](https://github.com/user-attachments/assets/7ceaf345-0992-46e6-b38f-394c2065b165)|![Image](https://github.com/user-attachments/assets/f7c26c21-6894-4d9e-b570-f1d44ca7c1de)|![Image](https://github.com/user-attachments/assets/c2bebe3b-5984-41ba-94bf-9509f6a8a990)|\n\n</details>\n\n\n<details>\n\n<summary>AttriCtrl: 图像生成模型的属性强度控制</summary>\n\n- 论文：[AttriCtrl: Fine-Grained Control of Aesthetic Attribute Intensity in Diffusion Models\n](https://arxiv.org/abs/2508.02151)\n- 代码样例：[/examples/flux/model_inference/FLUX.1-dev-AttriCtrl.py](/examples/flux/model_inference/FLUX.1-dev-AttriCtrl.py)\n- 模型：[ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev)\n\n|brightness scale = 0.1|brightness scale = 
0.3|brightness scale = 0.5|brightness scale = 0.7|brightness scale = 0.9|\n|-|-|-|-|-|\n|![Image](https://github.com/user-attachments/assets/e74b32a5-5b2e-4c87-9df8-487c0f8366b7)|![Image](https://github.com/user-attachments/assets/bfe8bec2-9e55-493d-9a26-7e9cce28e03d)|![Image](https://github.com/user-attachments/assets/b099dfe3-ff1f-4b96-894c-d48bbe92db7a)|![Image](https://github.com/user-attachments/assets/0a6b2982-deab-4b0d-91ad-888782de01c9)|![Image](https://github.com/user-attachments/assets/fcecb755-7d03-4020-b83a-13ad2b38705c)|\n\n</details>\n\n\n<details>\n\n<summary>AutoLoRA: 自动化的 LoRA 检索和融合</summary>\n\n- 论文：[AutoLoRA: Automatic LoRA Retrieval and Fine-Grained Gated Fusion for Text-to-Image Generation\n](https://arxiv.org/abs/2508.02107)\n- 代码样例：[/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py](/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py)\n- 模型：[ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev)\n\n||[LoRA 1](https://modelscope.cn/models/cancel13/cxsk)|[LoRA 2](https://modelscope.cn/models/wy413928499/xuancai2)|[LoRA 3](https://modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1)|[LoRA 4](https://modelscope.cn/models/hongyanbujian/JPL)|\n|-|-|-|-|-|\n|[LoRA 1](https://modelscope.cn/models/cancel13/cxsk)                              |![Image](https://github.com/user-attachments/assets/01c54d5a-4f00-4c2e-982a-4ec0a4c6a6e3)|![Image](https://github.com/user-attachments/assets/e6621457-b9f1-437c-bcc8-3e12e41646de)|![Image](https://github.com/user-attachments/assets/4b7f721f-a2e5-416c-af2c-b53ef236c321)|![Image](https://github.com/user-attachments/assets/802d554e-0402-482c-9f28-87605f8fe318)|\n|[LoRA 2](https://modelscope.cn/models/wy413928499/xuancai2)                       
|![Image](https://github.com/user-attachments/assets/e6621457-b9f1-437c-bcc8-3e12e41646de)|![Image](https://github.com/user-attachments/assets/43720a9f-aa27-4918-947d-545389375d46)|![Image](https://github.com/user-attachments/assets/418c725b-6d35-41f4-b18f-c7e3867cc142)|![Image](https://github.com/user-attachments/assets/8c8f22fa-9643-4019-b6d7-396d8b7fed9a)|\n|[LoRA 3](https://modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1)  |![Image](https://github.com/user-attachments/assets/4b7f721f-a2e5-416c-af2c-b53ef236c321)|![Image](https://github.com/user-attachments/assets/418c725b-6d35-41f4-b18f-c7e3867cc142)|![Image](https://github.com/user-attachments/assets/041a3f9a-c7b4-4311-8582-cb71a7226d80)|![Image](https://github.com/user-attachments/assets/b54ebaa4-31a7-4536-a2c1-496adba0c013)|\n|[LoRA 4](https://modelscope.cn/models/hongyanbujian/JPL)                          |![Image](https://github.com/user-attachments/assets/802d554e-0402-482c-9f28-87605f8fe318)|![Image](https://github.com/user-attachments/assets/8c8f22fa-9643-4019-b6d7-396d8b7fed9a)|![Image](https://github.com/user-attachments/assets/b54ebaa4-31a7-4536-a2c1-496adba0c013)|![Image](https://github.com/user-attachments/assets/a640fd54-3192-49a0-9281-b43d9ba64f09)|\n\n</details>\n\n\n<details>\n\n<summary>Nexus-Gen: 统一架构的图像理解、生成、编辑</summary>\n\n- 详细页面：https://github.com/modelscope/Nexus-Gen\n- 论文：[Nexus-Gen: Unified Image Understanding, Generation, and Editing via Prefilled Autoregression in Shared Embedding Space](https://arxiv.org/pdf/2504.21356)\n- 模型：[ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2), [HuggingFace](https://huggingface.co/modelscope/Nexus-GenV2)\n- 数据集：[ModelScope Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Nexus-Gen-Training-Dataset)\n- 在线体验：[ModelScope Nexus-Gen 
Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/Nexus-Gen)\n\n![](https://github.com/modelscope/Nexus-Gen/raw/main/assets/illustrations/gen_edit.jpg)\n\n</details>\n\n\n<details>\n\n<summary>ArtAug: 图像生成模型的美学提升</summary>\n\n- 详细页面：[./examples/ArtAug/](./examples/ArtAug/)\n- 论文：[ArtAug: Enhancing Text-to-Image Generation through Synthesis-Understanding Interaction](https://arxiv.org/abs/2412.12888)\n- 模型：[ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ArtAug-lora-FLUX.1dev-v1)\n- 在线体验：[ModelScope AIGC Tab](https://www.modelscope.cn/aigc/imageGeneration?tab=advanced&versionId=7228&modelType=LoRA&sdVersion=FLUX_1&modelUrl=modelscope%3A%2F%2FDiffSynth-Studio%2FArtAug-lora-FLUX.1dev-v1%3Frevision%3Dv1.0)\n\n|FLUX.1-dev|FLUX.1-dev + ArtAug LoRA|\n|-|-|\n|![image_1_base](https://github.com/user-attachments/assets/e1d5c505-b423-45fe-be01-25c2758f5417)|![image_1_enhance](https://github.com/user-attachments/assets/335908e3-d0bd-41c2-9d99-d10528a2d719)|\n\n</details>\n\n\n<details>\n\n<summary>EliGen: 精准的图像分区控制</summary>\n\n- 论文：[EliGen: Entity-Level Controlled Image Generation with Regional Attention](https://arxiv.org/abs/2501.01097)\n- 代码样例：[/examples/flux/model_inference/FLUX.1-dev-EliGen.py](/examples/flux/model_inference/FLUX.1-dev-EliGen.py)\n- 模型：[ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen), [HuggingFace](https://huggingface.co/modelscope/EliGen)\n- 在线体验：[ModelScope EliGen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/EliGen)\n- 数据集：[EliGen Train Set](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet)\n\n|实体控制区域|生成图像|\n|-|-|\n|![eligen_example_2_mask_0](https://github.com/user-attachments/assets/1c6d9445-5022-4d91-ad2e-dc05321883d1)|![eligen_example_2_0](https://github.com/user-attachments/assets/86739945-cb07-4a49-b3b3-3bb65c90d14f)|\n\n</details>\n\n\n<details>\n\n<summary>ExVideo: 视频生成模型的扩展训练</summary>\n\n- 
项目页面：[Project Page](https://ecnu-cilab.github.io/ExVideoProjectPage/)\n- 论文：[ExVideo: Extending Video Diffusion Models via Parameter-Efficient Post-Tuning](https://arxiv.org/abs/2406.14130)\n- 代码样例：请前往[旧版本](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3/examples/ExVideo)查看\n- 模型：[ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1)\n\nhttps://github.com/modelscope/DiffSynth-Studio/assets/35051019/d97f6aa9-8064-4b5b-9d49-ed6001bb9acc\n\n</details>\n\n\n<details>\n\n<summary>Diffutoon: 高分辨率动漫风格视频渲染</summary>\n\n- 项目页面：[Project Page](https://ecnu-cilab.github.io/DiffutoonProjectPage/)\n- 论文：[Diffutoon: High-Resolution Editable Toon Shading via Diffusion Models](https://arxiv.org/abs/2401.16224)\n- 代码样例：请前往[旧版本](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3/examples/Diffutoon)查看\n\nhttps://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/b54c05c5-d747-4709-be5e-b39af82404dd\n\n</details>\n\n\n<details>\n\n<summary>DiffSynth: 本项目的初代版本</summary>\n\n- 项目页面：[Project Page](https://ecnu-cilab.github.io/DiffSynth.github.io/)\n- 论文：[DiffSynth: Latent In-Iteration Deflickering for Realistic Video Synthesis](https://arxiv.org/abs/2308.03463)\n- 代码样例：请前往[旧版本](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3/examples/diffsynth)查看\n\nhttps://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-4481-b79f-0c3a7361a1ea\n\n</details>\n"
  },
  {
    "path": "diffsynth/__init__.py",
    "content": "from .core import *\n"
  },
  {
    "path": "diffsynth/configs/__init__.py",
    "content": "from .model_configs import MODEL_CONFIGS\nfrom .vram_management_module_maps import VRAM_MANAGEMENT_MODULE_MAPS, VERSION_CHECKER_MAPS\n"
  },
  {
    "path": "diffsynth/configs/model_configs.py",
    "content": "qwen_image_series = [\n    {\n        # Example: ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\")\n        \"model_hash\": \"0319a1cb19835fb510907dd3367c95ff\",\n        \"model_name\": \"qwen_image_dit\",\n        \"model_class\": \"diffsynth.models.qwen_image_dit.QwenImageDiT\",\n    },\n    {\n        # Example: ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\")\n        \"model_hash\": \"8004730443f55db63092006dd9f7110e\",\n        \"model_name\": \"qwen_image_text_encoder\",\n        \"model_class\": \"diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder\",\n        \"state_dict_converter\": \"diffsynth.utils.state_dict_converters.qwen_image_text_encoder.QwenImageTextEncoderStateDictConverter\",\n    },\n    {\n        # Example: ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\")\n        \"model_hash\": \"ed4ea5824d55ec3107b09815e318123a\",\n        \"model_name\": \"qwen_image_vae\",\n        \"model_class\": \"diffsynth.models.qwen_image_vae.QwenImageVAE\",\n    },\n    {\n        # Example: ModelConfig(model_id=\"DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth\", origin_file_pattern=\"model.safetensors\")\n        \"model_hash\": \"073bce9cf969e317e5662cd570c3e79c\",\n        \"model_name\": \"qwen_image_blockwise_controlnet\",\n        \"model_class\": \"diffsynth.models.qwen_image_controlnet.QwenImageBlockWiseControlNet\",\n    },\n    {\n        # Example: ModelConfig(model_id=\"DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint\", origin_file_pattern=\"model.safetensors\")\n        \"model_hash\": \"a9e54e480a628f0b956a688a81c33bab\",\n        \"model_name\": \"qwen_image_blockwise_controlnet\",\n        \"model_class\": \"diffsynth.models.qwen_image_controlnet.QwenImageBlockWiseControlNet\",\n        \"extra_kwargs\": {\"additional_in_dim\": 
4},\n    },\n    {\n        # Example: ModelConfig(model_id=\"DiffSynth-Studio/General-Image-Encoders\", origin_file_pattern=\"SigLIP2-G384/model.safetensors\")\n        \"model_hash\": \"469c78b61e3e31bc9eec0d0af3d3f2f8\",\n        \"model_name\": \"siglip2_image_encoder\",\n        \"model_class\": \"diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder\",\n    },\n    {\n        # Example: ModelConfig(model_id=\"DiffSynth-Studio/General-Image-Encoders\", origin_file_pattern=\"DINOv3-7B/model.safetensors\")\n        \"model_hash\": \"5722b5c873720009de96422993b15682\",\n        \"model_name\": \"dinov3_image_encoder\",\n        \"model_class\": \"diffsynth.models.dinov3_image_encoder.DINOv3ImageEncoder\",\n    },\n    {\n        # Example: \n        \"model_hash\": \"a166c33455cdbd89c0888a3645ca5c0f\",\n        \"model_name\": \"qwen_image_image2lora_coarse\",\n        \"model_class\": \"diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel\",\n    },\n    {\n        # Example: \n        \"model_hash\": \"a5476e691767a4da6d3a6634a10f7408\",\n        \"model_name\": \"qwen_image_image2lora_fine\",\n        \"model_class\": \"diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel\",\n        \"extra_kwargs\": {\"residual_length\": 37*37+7, \"residual_mid_dim\": 64}\n    },\n    {\n        # Example: \n        \"model_hash\": \"0aad514690602ecaff932c701cb4b0bb\",\n        \"model_name\": \"qwen_image_image2lora_style\",\n        \"model_class\": \"diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel\",\n        \"extra_kwargs\": {\"compress_dim\": 64, \"use_residual\": False}\n    },\n    {\n        # Example: ModelConfig(model_id=\"Qwen/Qwen-Image-Layered\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\")\n        \"model_hash\": \"8dc8cda05de16c73afa755e2c1ce2839\",\n        \"model_name\": \"qwen_image_dit\",\n        \"model_class\": \"diffsynth.models.qwen_image_dit.QwenImageDiT\",\n        
\"extra_kwargs\": {\"use_layer3d_rope\": True, \"use_additional_t_cond\": True}\n    },\n    {\n        # Example: ModelConfig(model_id=\"Qwen/Qwen-Image-Layered\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\")\n        \"model_hash\": \"44b39ddc499e027cfb24f7878d7416b9\",\n        \"model_name\": \"qwen_image_vae\",\n        \"model_class\": \"diffsynth.models.qwen_image_vae.QwenImageVAE\",\n        \"extra_kwargs\": {\"image_channels\": 4}\n    },\n]\n\nwan_series = [\n    {\n        # Example: ModelConfig(model_id=\"krea/krea-realtime-video\", origin_file_pattern=\"krea-realtime-video-14b.safetensors\")\n        \"model_hash\": \"5ec04e02b42d2580483ad69f4e76346a\",\n        \"model_name\": \"wan_video_dit\",\n        \"model_class\": \"diffsynth.models.wan_video_dit.WanModel\",\n        \"extra_kwargs\": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06},\n        \"state_dict_converter\": \"diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTStateDictConverter\",\n    },\n    {\n        # Example: ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-14B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\")\n        \"model_hash\": \"9c8818c2cbea55eca56c7b447df170da\",\n        \"model_name\": \"wan_video_text_encoder\",\n        \"model_class\": \"diffsynth.models.wan_video_text_encoder.WanTextEncoder\",\n    },\n    {\n        # Example: ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-14B\", origin_file_pattern=\"Wan2.1_VAE.pth\")\n        \"model_hash\": \"ccc42284ea13e1ad04693284c7a09be6\",\n        \"model_name\": \"wan_video_vae\",\n        \"model_class\": \"diffsynth.models.wan_video_vae.WanVideoVAE\",\n        \"state_dict_converter\": \"diffsynth.utils.state_dict_converters.wan_video_vae.WanVideoVAEStateDictConverter\",\n    },\n    {\n        # Example: 
ModelConfig(model_id=\"meituan-longcat/LongCat-Video\", origin_file_pattern=\"dit/diffusion_pytorch_model*.safetensors\")\n        \"model_hash\": \"8b27900f680d7251ce44e2dc8ae1ffef\",\n        \"model_name\": \"wan_video_dit\",\n        \"model_class\": \"diffsynth.models.longcat_video_dit.LongCatVideoTransformer3DModel\",\n    },\n    {\n        # Example: ModelConfig(model_id=\"ByteDance/Video-As-Prompt-Wan2.1-14B\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\")\n        \"model_hash\": \"5f90e66a0672219f12d9a626c8c21f61\",\n        \"model_name\": \"wan_video_dit\",\n        \"model_class\": \"diffsynth.models.wan_video_dit.WanModel\",\n        \"extra_kwargs\": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06},\n        \"state_dict_converter\": \"diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTFromDiffusers\"\n    },\n    {\n        # Example: ModelConfig(model_id=\"ByteDance/Video-As-Prompt-Wan2.1-14B\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\")\n        \"model_hash\": \"5f90e66a0672219f12d9a626c8c21f61\",\n        \"model_name\": \"wan_video_vap\",\n        \"model_class\": \"diffsynth.models.wan_video_mot.MotWanModel\",\n        \"state_dict_converter\": \"diffsynth.utils.state_dict_converters.wan_video_mot.WanVideoMotStateDictConverter\"\n    },\n    {\n        # Example: ModelConfig(model_id=\"Wan-AI/Wan2.1-I2V-14B-480P\", origin_file_pattern=\"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\")\n        \"model_hash\": \"5941c53e207d62f20f9025686193c40b\",\n        \"model_name\": \"wan_video_image_encoder\",\n        \"model_class\": \"diffsynth.models.wan_video_image_encoder.WanImageEncoder\",\n        \"state_dict_converter\": \"diffsynth.utils.state_dict_converters.wan_video_image_encoder.WanImageEncoderStateDictConverter\"\n 
   },\n    {\n        # Example: ModelConfig(model_id=\"DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1\", origin_file_pattern=\"model.safetensors\")\n        \"model_hash\": \"dbd5ec76bbf977983f972c151d545389\",\n        \"model_name\": \"wan_video_motion_controller\",\n        \"model_class\": \"diffsynth.models.wan_video_motion_controller.WanMotionControllerModel\",\n    },\n    {\n        # Example: ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\")\n        \"model_hash\": \"9269f8db9040a9d860eaca435be61814\",\n        \"model_name\": \"wan_video_dit\",\n        \"model_class\": \"diffsynth.models.wan_video_dit.WanModel\",\n        \"extra_kwargs\": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06}\n    },\n    {\n        # Example: ModelConfig(model_id=\"Wan-AI/Wan2.1-FLF2V-14B-720P\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\")\n        \"model_hash\": \"3ef3b1f8e1dab83d5b71fd7b617f859f\",\n        \"model_name\": \"wan_video_dit\",\n        \"model_class\": \"diffsynth.models.wan_video_dit.WanModel\",\n        \"extra_kwargs\": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_image_pos_emb': True}\n    },\n    {\n        # Example: ModelConfig(model_id=\"PAI/Wan2.1-Fun-1.3B-Control\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\")\n        \"model_hash\": \"349723183fc063b2bfc10bb2835cf677\",\n        \"model_name\": \"wan_video_dit\",\n        \"model_class\": \"diffsynth.models.wan_video_dit.WanModel\",\n        \"extra_kwargs\": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 
'num_heads': 12, 'num_layers': 30, 'eps': 1e-06}\n    },\n    {\n        # Example: ModelConfig(model_id=\"PAI/Wan2.1-Fun-1.3B-InP\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\")\n        \"model_hash\": \"6d6ccde6845b95ad9114ab993d917893\",\n        \"model_name\": \"wan_video_dit\",\n        \"model_class\": \"diffsynth.models.wan_video_dit.WanModel\",\n        \"extra_kwargs\": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06}\n    },\n    {\n        # Example: ModelConfig(model_id=\"PAI/Wan2.1-Fun-14B-Control\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\")\n        \"model_hash\": \"efa44cddf936c70abd0ea28b6cbe946c\",\n        \"model_name\": \"wan_video_dit\",\n        \"model_class\": \"diffsynth.models.wan_video_dit.WanModel\",\n        \"extra_kwargs\": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06}\n    },\n    {\n        # Example: ModelConfig(model_id=\"PAI/Wan2.1-Fun-14B-InP\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\")\n        \"model_hash\": \"6bfcfb3b342cb286ce886889d519a77e\",\n        \"model_name\": \"wan_video_dit\",\n        \"model_class\": \"diffsynth.models.wan_video_dit.WanModel\",\n        \"extra_kwargs\": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06}\n    },\n    {\n        # Example: ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\")\n        \"model_hash\": \"ac6a5aa74f4a0aab6f64eb9a72f19901\",\n        \"model_name\": \"wan_video_dit\",\n        \"model_class\": 
\"diffsynth.models.wan_video_dit.WanModel\",\n        \"extra_kwargs\": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 32, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06, 'has_ref_conv': False, 'add_control_adapter': True, 'in_dim_control_adapter': 24}\n    },\n    {\n        # Example: ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-1.3B-Control\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\")\n        \"model_hash\": \"70ddad9d3a133785da5ea371aae09504\",\n        \"model_name\": \"wan_video_dit\",\n        \"model_class\": \"diffsynth.models.wan_video_dit.WanModel\",\n        \"extra_kwargs\": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06, 'has_ref_conv': True}\n    },\n    {\n        # Example: ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-14B-Control-Camera\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\")\n        \"model_hash\": \"b61c605c2adbd23124d152ed28e049ae\",\n        \"model_name\": \"wan_video_dit\",\n        \"model_class\": \"diffsynth.models.wan_video_dit.WanModel\",\n        \"extra_kwargs\": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 32, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_ref_conv': False, 'add_control_adapter': True, 'in_dim_control_adapter': 24}\n    },\n    {\n        # Example: ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-14B-Control\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\")\n        \"model_hash\": \"26bde73488a92e64cc20b0a7485b9e5b\",\n        \"model_name\": \"wan_video_dit\",\n        \"model_class\": \"diffsynth.models.wan_video_dit.WanModel\",\n        \"extra_kwargs\": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 48, 
'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_ref_conv': True}\n    },\n    {\n        # Example: ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-14B\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\")\n        \"model_hash\": \"aafcfd9672c3a2456dc46e1cb6e52c70\",\n        \"model_name\": \"wan_video_dit\",\n        \"model_class\": \"diffsynth.models.wan_video_dit.WanModel\",\n        \"extra_kwargs\": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06}\n    },\n    {\n        # Example: ModelConfig(model_id=\"iic/VACE-Wan2.1-1.3B-Preview\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\")\n        \"model_hash\": \"a61453409b67cd3246cf0c3bebad47ba\",\n        \"model_name\": \"wan_video_dit\",\n        \"model_class\": \"diffsynth.models.wan_video_dit.WanModel\",\n        \"extra_kwargs\": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06},\n        \"state_dict_converter\": \"diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTStateDictConverter\",\n    },\n    {\n        # Example: ModelConfig(model_id=\"iic/VACE-Wan2.1-1.3B-Preview\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\")\n        \"model_hash\": \"a61453409b67cd3246cf0c3bebad47ba\",\n        \"model_name\": \"wan_video_vace\",\n        \"model_class\": \"diffsynth.models.wan_video_vace.VaceWanModel\",\n        \"state_dict_converter\": \"diffsynth.utils.state_dict_converters.wan_video_vace.VaceWanModelDictConverter\"\n    },\n    {\n        # Example: ModelConfig(model_id=\"Wan-AI/Wan2.1-VACE-14B\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\")\n        \"model_hash\": 
\"7a513e1f257a861512b1afd387a8ecd9\",\n        \"model_name\": \"wan_video_dit\",\n        \"model_class\": \"diffsynth.models.wan_video_dit.WanModel\",\n        \"extra_kwargs\": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06},\n        \"state_dict_converter\": \"diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTStateDictConverter\",\n    },\n    {\n        # Example: ModelConfig(model_id=\"Wan-AI/Wan2.1-VACE-14B\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\")\n        \"model_hash\": \"7a513e1f257a861512b1afd387a8ecd9\",\n        \"model_name\": \"wan_video_vace\",\n        \"model_class\": \"diffsynth.models.wan_video_vace.VaceWanModel\",\n        \"extra_kwargs\": {'vace_layers': (0, 5, 10, 15, 20, 25, 30, 35), 'vace_in_dim': 96, 'patch_size': (1, 2, 2), 'has_image_input': False, 'dim': 5120, 'num_heads': 40, 'ffn_dim': 13824, 'eps': 1e-06},\n        \"state_dict_converter\": \"diffsynth.utils.state_dict_converters.wan_video_vace.VaceWanModelDictConverter\"\n    },\n    {\n        # Example: ModelConfig(model_id=\"Wan-AI/Wan2.2-Animate-14B\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\")\n        \"model_hash\": \"31fa352acb8a1b1d33cd8764273d80a2\",\n        \"model_name\": \"wan_video_dit\",\n        \"model_class\": \"diffsynth.models.wan_video_dit.WanModel\",\n        \"extra_kwargs\": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06},\n        \"state_dict_converter\": \"diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTStateDictConverter\"\n    },\n    {\n        # Example: ModelConfig(model_id=\"Wan-AI/Wan2.2-Animate-14B\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\")\n        \"model_hash\": 
\"31fa352acb8a1b1d33cd8764273d80a2\",\n        \"model_name\": \"wan_video_animate_adapter\",\n        \"model_class\": \"diffsynth.models.wan_video_animate_adapter.WanAnimateAdapter\",\n        \"state_dict_converter\": \"diffsynth.utils.state_dict_converters.wan_video_animate_adapter.WanAnimateAdapterStateDictConverter\"\n    },\n    {\n        # Example: ModelConfig(model_id=\"PAI/Wan2.2-Fun-A14B-Control-Camera\", origin_file_pattern=\"high_noise_model/diffusion_pytorch_model*.safetensors\")\n        \"model_hash\": \"47dbeab5e560db3180adf51dc0232fb1\",\n        \"model_name\": \"wan_video_dit\",\n        \"model_class\": \"diffsynth.models.wan_video_dit.WanModel\",\n        \"extra_kwargs\": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_ref_conv': False, 'add_control_adapter': True, 'in_dim_control_adapter': 24, 'require_clip_embedding': False}\n    },\n    {\n        # Example: ModelConfig(model_id=\"PAI/Wan2.2-Fun-A14B-Control\", origin_file_pattern=\"high_noise_model/diffusion_pytorch_model*.safetensors\")\n        \"model_hash\": \"2267d489f0ceb9f21836532952852ee5\",\n        \"model_name\": \"wan_video_dit\",\n        \"model_class\": \"diffsynth.models.wan_video_dit.WanModel\",\n        \"extra_kwargs\": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 52, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_ref_conv': True, 'require_clip_embedding': False},\n    },\n    {\n        # Example: ModelConfig(model_id=\"Wan-AI/Wan2.2-I2V-A14B\", origin_file_pattern=\"high_noise_model/diffusion_pytorch_model*.safetensors\")\n        \"model_hash\": \"5b013604280dd715f8457c6ed6d6a626\",\n        \"model_name\": \"wan_video_dit\",\n        \"model_class\": \"diffsynth.models.wan_video_dit.WanModel\",\n        
\"extra_kwargs\": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'require_clip_embedding': False}\n    },\n    {\n        # Example: ModelConfig(model_id=\"Wan-AI/Wan2.2-S2V-14B\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\")\n        \"model_hash\": \"966cffdcc52f9c46c391768b27637614\",\n        \"model_name\": \"wan_video_dit\",\n        \"model_class\": \"diffsynth.models.wan_video_dit_s2v.WanS2VModel\",\n        \"extra_kwargs\": {'dim': 5120, 'in_dim': 16, 'ffn_dim': 13824, 'out_dim': 16, 'text_dim': 4096, 'freq_dim': 256, 'eps': 1e-06, 'patch_size': (1, 2, 2), 'num_heads': 40, 'num_layers': 40, 'cond_dim': 16, 'audio_dim': 1024, 'num_audio_token': 4}\n    },\n    {\n        # Example: ModelConfig(model_id=\"Wan-AI/Wan2.2-TI2V-5B\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\")\n        \"model_hash\": \"1f5ab7703c6fc803fdded85ff040c316\",\n        \"model_name\": \"wan_video_dit\",\n        \"model_class\": \"diffsynth.models.wan_video_dit.WanModel\",\n        \"extra_kwargs\": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 3072, 'ffn_dim': 14336, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 48, 'num_heads': 24, 'num_layers': 30, 'eps': 1e-06, 'seperated_timestep': True, 'require_clip_embedding': False, 'require_vae_embedding': False, 'fuse_vae_embedding_in_latents': True}\n    },\n    {\n        # Example: ModelConfig(model_id=\"Wan-AI/Wan2.2-TI2V-5B\", origin_file_pattern=\"Wan2.2_VAE.pth\")\n        \"model_hash\": \"e1de6c02cdac79f8b739f4d3698cd216\",\n        \"model_name\": \"wan_video_vae\",\n        \"model_class\": \"diffsynth.models.wan_video_vae.WanVideoVAE38\",\n        \"state_dict_converter\": \"diffsynth.utils.state_dict_converters.wan_video_vae.WanVideoVAEStateDictConverter\",\n    },\n    {\n        # Example: 
ModelConfig(model_id=\"Wan-AI/Wan2.2-S2V-14B\", origin_file_pattern=\"wav2vec2-large-xlsr-53-english/model.safetensors\")\n        \"model_hash\": \"06be60f3a4526586d8431cd038a71486\",\n        \"model_name\": \"wans2v_audio_encoder\",\n        \"model_class\": \"diffsynth.models.wav2vec.WanS2VAudioEncoder\",\n        \"state_dict_converter\": \"diffsynth.utils.state_dict_converters.wans2v_audio_encoder.WanS2VAudioEncoderStateDictConverter\",\n    },\n    {\n        # Example: ModelConfig(model_id=\"Wan-AI/WanToDance-14B\", origin_file_pattern=\"global_model.safetensors\")\n        \"model_hash\": \"eb18873fc0ba77b541eb7b62dbcd2059\",\n        \"model_name\": \"wan_video_dit\",\n        \"model_class\": \"diffsynth.models.wan_video_dit.WanModel\",\n        \"extra_kwargs\": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'wantodance_enable_music_inject': True, 'wantodance_music_inject_layers': [0, 4, 8, 12, 16, 20, 24, 27], 'wantodance_enable_refimage': True, 'has_ref_conv': True, 'wantodance_enable_refface': False, 'wantodance_enable_global': True, 'wantodance_enable_dynamicfps': True, 'wantodance_enable_unimodel': True}\n    },\n]\n\nflux_series = [\n    {\n        # Example: ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"flux1-dev.safetensors\")\n        \"model_hash\": \"a29710fea6dddb0314663ee823598e50\",\n        \"model_name\": \"flux_dit\",\n        \"model_class\": \"diffsynth.models.flux_dit.FluxDiT\",\n        \"state_dict_converter\": \"diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter\",\n    },\n    {\n        # Supported due to historical reasons.\n        \"model_hash\": \"605c56eab23e9e2af863ad8f0813a25d\",\n        \"model_name\": \"flux_dit\",\n        \"model_class\": \"diffsynth.models.flux_dit.FluxDiT\",\n        \"state_dict_converter\": 
\"diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverterFromDiffusers\",\n    },\n    {\n        # Example: ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder/model.safetensors\")\n        \"model_hash\": \"94eefa3dac9cec93cb1ebaf1747d7b78\",\n        \"model_name\": \"flux_text_encoder_clip\",\n        \"model_class\": \"diffsynth.models.flux_text_encoder_clip.FluxTextEncoderClip\",\n        \"state_dict_converter\": \"diffsynth.utils.state_dict_converters.flux_text_encoder_clip.FluxTextEncoderClipStateDictConverter\",\n    },\n    {\n        # Example: ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder_2/*.safetensors\")\n        \"model_hash\": \"22540b49eaedbc2f2784b2091a234c7c\",\n        \"model_name\": \"flux_text_encoder_t5\",\n        \"model_class\": \"diffsynth.models.flux_text_encoder_t5.FluxTextEncoderT5\",\n        \"state_dict_converter\": \"diffsynth.utils.state_dict_converters.flux_text_encoder_t5.FluxTextEncoderT5StateDictConverter\",\n    },\n    {\n        # Example: ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"ae.safetensors\")\n        \"model_hash\": \"21ea55f476dfc4fd135587abb59dfe5d\",\n        \"model_name\": \"flux_vae_encoder\",\n        \"model_class\": \"diffsynth.models.flux_vae.FluxVAEEncoder\",\n        \"state_dict_converter\": \"diffsynth.utils.state_dict_converters.flux_vae.FluxVAEEncoderStateDictConverter\",\n    },\n    {\n        # Example: ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"ae.safetensors\")\n        \"model_hash\": \"21ea55f476dfc4fd135587abb59dfe5d\",\n        \"model_name\": \"flux_vae_decoder\",\n        \"model_class\": \"diffsynth.models.flux_vae.FluxVAEDecoder\",\n        \"state_dict_converter\": \"diffsynth.utils.state_dict_converters.flux_vae.FluxVAEDecoderStateDictConverter\",\n    },\n    {\n        # Example: 
ModelConfig(model_id=\"ostris/Flex.2-preview\", origin_file_pattern=\"Flex.2-preview.safetensors\")\n        \"model_hash\": \"d02f41c13549fa5093d3521f62a5570a\",\n        \"model_name\": \"flux_dit\",\n        \"model_class\": \"diffsynth.models.flux_dit.FluxDiT\",\n        \"extra_kwargs\": {'input_dim': 196, 'num_blocks': 8},\n        \"state_dict_converter\": \"diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter\",\n    },\n    {\n        # Example: ModelConfig(model_id=\"DiffSynth-Studio/AttriCtrl-FLUX.1-Dev\", origin_file_pattern=\"models/brightness.safetensors\")\n        \"model_hash\": \"0629116fce1472503a66992f96f3eb1a\",\n        \"model_name\": \"flux_value_controller\",\n        \"model_class\": \"diffsynth.models.flux_value_control.SingleValueEncoder\",\n    },\n    {\n        # Example: ModelConfig(model_id=\"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta\", origin_file_pattern=\"diffusion_pytorch_model.safetensors\")\n        \"model_hash\": \"52357cb26250681367488a8954c271e8\",\n        \"model_name\": \"flux_controlnet\",\n        \"model_class\": \"diffsynth.models.flux_controlnet.FluxControlNet\",\n        \"state_dict_converter\": \"diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter\",\n        \"extra_kwargs\": {\"num_joint_blocks\": 6, \"num_single_blocks\": 0, \"additional_input_dim\": 4},\n    },\n    {\n        # Example: ModelConfig(model_id=\"InstantX/FLUX.1-dev-Controlnet-Union-alpha\", origin_file_pattern=\"diffusion_pytorch_model.safetensors\")\n        \"model_hash\": \"78d18b9101345ff695f312e7e62538c0\",\n        \"model_name\": \"flux_controlnet\",\n        \"model_class\": \"diffsynth.models.flux_controlnet.FluxControlNet\",\n        \"state_dict_converter\": \"diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter\",\n        \"extra_kwargs\": {\"num_mode\": 10, \"mode_dict\": {\"canny\": 0, \"tile\": 1, \"depth\": 2, \"blur\": 3, 
\"pose\": 4, \"gray\": 5, \"lq\": 6}},\n    },\n    {\n        # Example: ModelConfig(model_id=\"jasperai/Flux.1-dev-Controlnet-Upscaler\", origin_file_pattern=\"diffusion_pytorch_model.safetensors\")\n        \"model_hash\": \"b001c89139b5f053c715fe772362dd2a\",\n        \"model_name\": \"flux_controlnet\",\n        \"model_class\": \"diffsynth.models.flux_controlnet.FluxControlNet\",\n        \"state_dict_converter\": \"diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter\",\n        \"extra_kwargs\": {\"num_single_blocks\": 0},\n    },\n    {\n        # Example: ModelConfig(model_id=\"ByteDance/InfiniteYou\", origin_file_pattern=\"infu_flux_v1.0/aes_stage2/image_proj_model.bin\")\n        \"model_hash\": \"c07c0f04f5ff55e86b4e937c7a40d481\",\n        \"model_name\": \"infiniteyou_image_projector\",\n        \"model_class\": \"diffsynth.models.flux_infiniteyou.InfiniteYouImageProjector\",\n        \"state_dict_converter\": \"diffsynth.utils.state_dict_converters.flux_infiniteyou.FluxInfiniteYouImageProjectorStateDictConverter\",\n    },\n    {\n        # Example: ModelConfig(model_id=\"ByteDance/InfiniteYou\", origin_file_pattern=\"infu_flux_v1.0/aes_stage2/InfuseNetModel/*.safetensors\")\n        \"model_hash\": \"7f9583eb8ba86642abb9a21a4b2c9e16\",\n        \"model_name\": \"flux_controlnet\",\n        \"model_class\": \"diffsynth.models.flux_controlnet.FluxControlNet\",\n        \"state_dict_converter\": \"diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter\",\n        \"extra_kwargs\": {\"num_joint_blocks\": 4, \"num_single_blocks\": 10},\n    },\n    {\n        # Example: ModelConfig(model_id=\"DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev\", origin_file_pattern=\"model.safetensors\")\n        \"model_hash\": \"77c2e4dd2440269eb33bfaa0d004f6ab\",\n        \"model_name\": \"flux_lora_encoder\",\n        \"model_class\": \"diffsynth.models.flux_lora_encoder.FluxLoRAEncoder\",\n    },\n    {\n      
  # Example: ModelConfig(model_id=\"DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev\", origin_file_pattern=\"model.safetensors\")\n        \"model_hash\": \"30143afb2dea73d1ac580e0787628f8c\",\n        \"model_name\": \"flux_lora_patcher\",\n        \"model_class\": \"diffsynth.models.flux_lora_patcher.FluxLoraPatcher\",\n    },\n    {\n        # Example: ModelConfig(model_id=\"DiffSynth-Studio/Nexus-GenV2\", origin_file_pattern=\"model*.safetensors\")\n        \"model_hash\": \"2bd19e845116e4f875a0a048e27fc219\",\n        \"model_name\": \"nexus_gen_llm\",\n        \"model_class\": \"diffsynth.models.nexus_gen.NexusGenAutoregressiveModel\",\n        \"state_dict_converter\": \"diffsynth.utils.state_dict_converters.nexus_gen.NexusGenAutoregressiveModelStateDictConverter\",\n    },\n    {\n        # Example: ModelConfig(model_id=\"DiffSynth-Studio/Nexus-GenV2\", origin_file_pattern=\"edit_decoder.bin\")\n        \"model_hash\": \"63c969fd37cce769a90aa781fbff5f81\",\n        \"model_name\": \"nexus_gen_editing_adapter\",\n        \"model_class\": \"diffsynth.models.nexus_gen_projector.NexusGenImageEmbeddingMerger\",\n        \"state_dict_converter\": \"diffsynth.utils.state_dict_converters.nexus_gen_projector.NexusGenMergerStateDictConverter\",\n    },\n    {\n        # Example: ModelConfig(model_id=\"DiffSynth-Studio/Nexus-GenV2\", origin_file_pattern=\"edit_decoder.bin\")\n        \"model_hash\": \"63c969fd37cce769a90aa781fbff5f81\",\n        \"model_name\": \"flux_dit\",\n        \"model_class\": \"diffsynth.models.flux_dit.FluxDiT\",\n        \"state_dict_converter\": \"diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter\",\n    },\n    {\n        # Example: ModelConfig(model_id=\"DiffSynth-Studio/Nexus-GenV2\", origin_file_pattern=\"generation_decoder.bin\")\n        \"model_hash\": \"3e6c61b0f9471135fc9c6d6a98e98b6d\",\n        \"model_name\": \"nexus_gen_generation_adapter\",\n        \"model_class\": 
\"diffsynth.models.nexus_gen_projector.NexusGenAdapter\",\n        \"state_dict_converter\": \"diffsynth.utils.state_dict_converters.nexus_gen_projector.NexusGenAdapterStateDictConverter\",\n    },\n    {\n        # Example: ModelConfig(model_id=\"DiffSynth-Studio/Nexus-GenV2\", origin_file_pattern=\"generation_decoder.bin\")\n        \"model_hash\": \"3e6c61b0f9471135fc9c6d6a98e98b6d\",\n        \"model_name\": \"flux_dit\",\n        \"model_class\": \"diffsynth.models.flux_dit.FluxDiT\",\n        \"state_dict_converter\": \"diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter\",\n    },\n    {\n        # Example: ModelConfig(model_id=\"InstantX/FLUX.1-dev-IP-Adapter\", origin_file_pattern=\"ip-adapter.bin\")\n        \"model_hash\": \"4daaa66cc656a8fe369908693dad0a35\",\n        \"model_name\": \"flux_ipadapter\",\n        \"model_class\": \"diffsynth.models.flux_ipadapter.FluxIpAdapter\",\n        \"state_dict_converter\": \"diffsynth.utils.state_dict_converters.flux_ipadapter.FluxIpAdapterStateDictConverter\",\n    },\n    {\n        # Example: ModelConfig(model_id=\"google/siglip-so400m-patch14-384\", origin_file_pattern=\"model.safetensors\")\n        \"model_hash\": \"04d8c1e20a1f1b25f7434f111992a33f\",\n        \"model_name\": \"siglip_vision_model\",\n        \"model_class\": \"diffsynth.models.flux_ipadapter.SiglipVisionModelSO400M\",\n        \"state_dict_converter\": \"diffsynth.utils.state_dict_converters.flux_ipadapter.SiglipStateDictConverter\",\n    },\n    {\n        # Example: ModelConfig(model_id=\"stepfun-ai/Step1X-Edit\", origin_file_pattern=\"step1x-edit-i1258.safetensors\"),\n        \"model_hash\": \"d30fb9e02b1dbf4e509142f05cf7dd50\",\n        \"model_name\": \"step1x_connector\",\n        \"model_class\": \"diffsynth.models.step1x_connector.Qwen2Connector\",\n        \"state_dict_converter\": \"diffsynth.utils.state_dict_converters.step1x_connector.Qwen2ConnectorStateDictConverter\",\n    },\n    {\n        # Example: 
ModelConfig(model_id=\"stepfun-ai/Step1X-Edit\", origin_file_pattern=\"step1x-edit-i1258.safetensors\"),\n        \"model_hash\": \"d30fb9e02b1dbf4e509142f05cf7dd50\",\n        \"model_name\": \"flux_dit\",\n        \"model_class\": \"diffsynth.models.flux_dit.FluxDiT\",\n        \"state_dict_converter\": \"diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter\",\n        \"extra_kwargs\": {\"disable_guidance_embedder\": True},\n    },\n    {\n        # Example: ModelConfig(model_id=\"MAILAND/majicflus_v1\", origin_file_pattern=\"majicflus_v134.safetensors\")\n        \"model_hash\": \"3394f306c4cbf04334b712bf5aaed95f\",\n        \"model_name\": \"flux_dit\",\n        \"model_class\": \"diffsynth.models.flux_dit.FluxDiT\",\n        \"state_dict_converter\": \"diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter\",\n    },\n]\n\nflux2_series = [\n    {\n        # Example: ModelConfig(model_id=\"black-forest-labs/FLUX.2-dev\", origin_file_pattern=\"text_encoder/*.safetensors\")\n        \"model_hash\": \"28fca3d8e5bf2a2d1271748a773f6757\",\n        \"model_name\": \"flux2_text_encoder\",\n        \"model_class\": \"diffsynth.models.flux2_text_encoder.Flux2TextEncoder\",\n        \"state_dict_converter\": \"diffsynth.utils.state_dict_converters.flux2_text_encoder.Flux2TextEncoderStateDictConverter\",\n    },\n    {\n        # Example: ModelConfig(model_id=\"black-forest-labs/FLUX.2-dev\", origin_file_pattern=\"transformer/*.safetensors\")\n        \"model_hash\": \"d38e1d5c5aec3b0a11e79327ac6e3b0f\",\n        \"model_name\": \"flux2_dit\",\n        \"model_class\": \"diffsynth.models.flux2_dit.Flux2DiT\",\n    },\n    {\n        # Example: ModelConfig(model_id=\"black-forest-labs/FLUX.2-dev\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\")\n        \"model_hash\": \"c54288e3ee12ca215898840682337b95\",\n        \"model_name\": \"flux2_vae\",\n        \"model_class\": \"diffsynth.models.flux2_vae.Flux2VAE\",\n    
},\n    {\n        # Example: ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-4B\", origin_file_pattern=\"transformer/*.safetensors\")\n        \"model_hash\": \"3bde7b817fec8143028b6825a63180df\",\n        \"model_name\": \"flux2_dit\",\n        \"model_class\": \"diffsynth.models.flux2_dit.Flux2DiT\",\n        \"extra_kwargs\": {\"guidance_embeds\": False, \"joint_attention_dim\": 7680, \"num_attention_heads\": 24, \"num_layers\": 5, \"num_single_layers\": 20}\n    },\n    {\n        # Example: ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-9B\", origin_file_pattern=\"text_encoder/*.safetensors\")\n        \"model_hash\": \"9195f3ea256fcd0ae6d929c203470754\",\n        \"model_name\": \"z_image_text_encoder\",\n        \"model_class\": \"diffsynth.models.z_image_text_encoder.ZImageTextEncoder\",\n        \"extra_kwargs\": {\"model_size\": \"8B\"},\n        \"state_dict_converter\": \"diffsynth.utils.state_dict_converters.z_image_text_encoder.ZImageTextEncoderStateDictConverter\",\n    },\n    {\n        # Example: ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-9B\", origin_file_pattern=\"transformer/*.safetensors\")\n        \"model_hash\": \"39c6fc48f07bebecedbbaa971ff466c8\",\n        \"model_name\": \"flux2_dit\",\n        \"model_class\": \"diffsynth.models.flux2_dit.Flux2DiT\",\n        \"extra_kwargs\": {\"guidance_embeds\": False, \"joint_attention_dim\": 12288, \"num_attention_heads\": 32, \"num_layers\": 8, \"num_single_layers\": 24}\n    },\n]\n\nz_image_series = [\n    {\n        # Example: ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"transformer/*.safetensors\")\n        \"model_hash\": \"fc3a8a1247fe185ce116ccbe0e426c28\",\n        \"model_name\": \"z_image_dit\",\n        \"model_class\": \"diffsynth.models.z_image_dit.ZImageDiT\",\n    },\n    {\n        # Example: ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"text_encoder/*.safetensors\")\n        \"model_hash\": 
\"0f050f62a88876fea6eae0a18dac5a2e\",\n        \"model_name\": \"z_image_text_encoder\",\n        \"model_class\": \"diffsynth.models.z_image_text_encoder.ZImageTextEncoder\",\n    },\n    {\n        # Example: ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"vae/vae/diffusion_pytorch_model.safetensors\")\n        \"model_hash\": \"1aafa3cc91716fb6b300cc1cd51b85a3\",\n        \"model_name\": \"flux_vae_encoder\",\n        \"model_class\": \"diffsynth.models.flux_vae.FluxVAEEncoder\",\n        \"state_dict_converter\": \"diffsynth.utils.state_dict_converters.flux_vae.FluxVAEEncoderStateDictConverterDiffusers\",\n        \"extra_kwargs\": {\"use_conv_attention\": False},\n    },\n    {\n        # Example: ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"vae/vae/diffusion_pytorch_model.safetensors\")\n        \"model_hash\": \"1aafa3cc91716fb6b300cc1cd51b85a3\",\n        \"model_name\": \"flux_vae_decoder\",\n        \"model_class\": \"diffsynth.models.flux_vae.FluxVAEDecoder\",\n        \"state_dict_converter\": \"diffsynth.utils.state_dict_converters.flux_vae.FluxVAEDecoderStateDictConverterDiffusers\",\n        \"extra_kwargs\": {\"use_conv_attention\": False},\n    },\n    {\n        # Example: ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Omni-Base\", origin_file_pattern=\"transformer/*.safetensors\")\n        \"model_hash\": \"aa3563718e5c3ecde3dfbb020ca61180\",\n        \"model_name\": \"z_image_dit\",\n        \"model_class\": \"diffsynth.models.z_image_dit.ZImageDiT\",\n        \"extra_kwargs\": {\"siglip_feat_dim\": 1152},\n    },\n    {\n        # Example: ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Omni-Base\", origin_file_pattern=\"siglip/model.safetensors\")\n        \"model_hash\": \"89d48e420f45cff95115a9f3e698d44a\",\n        \"model_name\": \"siglip_vision_model_428m\",\n        \"model_class\": \"diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder428M\",\n    },\n    {\n        # Example: 
ModelConfig(model_id=\"PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1\", origin_file_pattern=\"Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.safetensors\")\n        \"model_hash\": \"1677708d40029ab380a95f6c731a57d7\",\n        \"model_name\": \"z_image_controlnet\",\n        \"model_class\": \"diffsynth.models.z_image_controlnet.ZImageControlNet\",\n    },\n    {\n        # Example: ???\n        \"model_hash\": \"9510cb8cd1dd34ee0e4f111c24905510\",\n        \"model_name\": \"z_image_image2lora_style\",\n        \"model_class\": \"diffsynth.models.z_image_image2lora.ZImageImage2LoRAModel\",\n        \"extra_kwargs\": {\"compress_dim\": 128},\n    },\n    {\n        # Example: ModelConfig(model_id=\"Qwen/Qwen3-0.6B\", origin_file_pattern=\"model.safetensors\")\n        \"model_hash\": \"1392adecee344136041e70553f875f31\",\n        \"model_name\": \"z_image_text_encoder\",\n        \"model_class\": \"diffsynth.models.z_image_text_encoder.ZImageTextEncoder\",\n        \"extra_kwargs\": {\"model_size\": \"0.6B\"},\n        \"state_dict_converter\": \"diffsynth.utils.state_dict_converters.z_image_text_encoder.ZImageTextEncoderStateDictConverter\",\n    },\n]\n\"\"\"\nOfficial model repo: https://www.modelscope.cn/models/Lightricks/LTX-2\nRepackaged model repo: https://www.modelscope.cn/models/DiffSynth-Studio/LTX-2-Repackage\nFor the LTX-2 base models, both the official checkpoint (with model config ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-19b-dev.safetensors\"))\nand the repackaged checkpoints (with model config ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"*.safetensors\")) are supported.\nWe have repackaged the official checkpoint into the DiffSynth-Studio/LTX-2-Repackage repo so that each submodule can be loaded separately,\navoiding redundant memory usage when users only need part of the model.\n\"\"\"\nltx2_series = [\n    {\n        # Example: ModelConfig(model_id=\"Lightricks/LTX-2\", 
origin_file_pattern=\"ltx-2-19b-dev.safetensors\")\n        \"model_hash\": \"aca7b0bbf8415e9c98360750268915fc\",\n        \"model_name\": \"ltx2_dit\",\n        \"model_class\": \"diffsynth.models.ltx2_dit.LTXModel\",\n        \"state_dict_converter\": \"diffsynth.utils.state_dict_converters.ltx2_dit.LTXModelStateDictConverter\",\n    },\n    {\n        # Example: ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"transformer.safetensors\")\n        \"model_hash\": \"c567aaa37d5ed7454c73aa6024458661\",\n        \"model_name\": \"ltx2_dit\",\n        \"model_class\": \"diffsynth.models.ltx2_dit.LTXModel\",\n        \"state_dict_converter\": \"diffsynth.utils.state_dict_converters.ltx2_dit.LTXModelStateDictConverter\",\n    },\n    {\n        # Example: ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-19b-dev.safetensors\")\n        \"model_hash\": \"aca7b0bbf8415e9c98360750268915fc\",\n        \"model_name\": \"ltx2_video_vae_encoder\",\n        \"model_class\": \"diffsynth.models.ltx2_video_vae.LTX2VideoEncoder\",\n        \"state_dict_converter\": \"diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoEncoderStateDictConverter\",\n    },\n    {\n        # Example: ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"video_vae_encoder.safetensors\")\n        \"model_hash\": \"7f7e904a53260ec0351b05f32153754b\",\n        \"model_name\": \"ltx2_video_vae_encoder\",\n        \"model_class\": \"diffsynth.models.ltx2_video_vae.LTX2VideoEncoder\",\n        \"state_dict_converter\": \"diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoEncoderStateDictConverter\",\n    },\n    {\n        # Example: ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-19b-dev.safetensors\")\n        \"model_hash\": \"aca7b0bbf8415e9c98360750268915fc\",\n        \"model_name\": \"ltx2_video_vae_decoder\",\n        \"model_class\": 
\"diffsynth.models.ltx2_video_vae.LTX2VideoDecoder\",\n        \"state_dict_converter\": \"diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoDecoderStateDictConverter\",\n    },\n    {\n        # Example: ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"video_vae_decoder.safetensors\")\n        \"model_hash\": \"dc6029ca2825147872b45e35a2dc3a97\",\n        \"model_name\": \"ltx2_video_vae_decoder\",\n        \"model_class\": \"diffsynth.models.ltx2_video_vae.LTX2VideoDecoder\",\n        \"state_dict_converter\": \"diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoDecoderStateDictConverter\",\n    },\n    {\n        # Example: ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-19b-dev.safetensors\")\n        \"model_hash\": \"aca7b0bbf8415e9c98360750268915fc\",\n        \"model_name\": \"ltx2_audio_vae_decoder\",\n        \"model_class\": \"diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder\",\n        \"state_dict_converter\": \"diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioDecoderStateDictConverter\",\n    },\n    {\n        # Example: ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"audio_vae_decoder.safetensors\")\n        \"model_hash\": \"7d7823dde8f1ea0b50fb07ac329dd4cb\",\n        \"model_name\": \"ltx2_audio_vae_decoder\",\n        \"model_class\": \"diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder\",\n        \"state_dict_converter\": \"diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioDecoderStateDictConverter\",\n    },\n    {\n        # Example: ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-19b-dev.safetensors\")\n        \"model_hash\": \"aca7b0bbf8415e9c98360750268915fc\",\n        \"model_name\": \"ltx2_audio_vocoder\",\n        \"model_class\": \"diffsynth.models.ltx2_audio_vae.LTX2Vocoder\",\n        \"state_dict_converter\": 
\"diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2VocoderStateDictConverter\",\n    },\n    {\n        # Example: ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"audio_vocoder.safetensors\")\n        \"model_hash\": \"f471360f6b24bef702ab73133d9f8bb9\",\n        \"model_name\": \"ltx2_audio_vocoder\",\n        \"model_class\": \"diffsynth.models.ltx2_audio_vae.LTX2Vocoder\",\n        \"state_dict_converter\": \"diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2VocoderStateDictConverter\",\n    },\n    {\n        # Example: ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-19b-dev.safetensors\")\n        \"model_hash\": \"aca7b0bbf8415e9c98360750268915fc\",\n        \"model_name\": \"ltx2_audio_vae_encoder\",\n        \"model_class\": \"diffsynth.models.ltx2_audio_vae.LTX2AudioEncoder\",\n        \"state_dict_converter\": \"diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioEncoderStateDictConverter\",\n    },\n    {\n        # Example: ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"audio_vae_encoder.safetensors\")\n        \"model_hash\": \"29338f3b95e7e312a3460a482e4f4554\",\n        \"model_name\": \"ltx2_audio_vae_encoder\",\n        \"model_class\": \"diffsynth.models.ltx2_audio_vae.LTX2AudioEncoder\",\n        \"state_dict_converter\": \"diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioEncoderStateDictConverter\",\n    },\n    {\n        # Example: ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-19b-dev.safetensors\")\n        \"model_hash\": \"aca7b0bbf8415e9c98360750268915fc\",\n        \"model_name\": \"ltx2_text_encoder_post_modules\",\n        \"model_class\": \"diffsynth.models.ltx2_text_encoder.LTX2TextEncoderPostModules\",\n        \"state_dict_converter\": \"diffsynth.utils.state_dict_converters.ltx2_text_encoder.LTX2TextEncoderPostModulesStateDictConverter\",\n    },\n    {\n        # Example: 
ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"text_encoder_post_modules.safetensors\")\n        \"model_hash\": \"981629689c8be92a712ab3c5eb4fc3f6\",\n        \"model_name\": \"ltx2_text_encoder_post_modules\",\n        \"model_class\": \"diffsynth.models.ltx2_text_encoder.LTX2TextEncoderPostModules\",\n        \"state_dict_converter\": \"diffsynth.utils.state_dict_converters.ltx2_text_encoder.LTX2TextEncoderPostModulesStateDictConverter\",\n    },\n    {\n        # Example: ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\", origin_file_pattern=\"model-*.safetensors\")\n        \"model_hash\": \"33917f31c4a79196171154cca39f165e\",\n        \"model_name\": \"ltx2_text_encoder\",\n        \"model_class\": \"diffsynth.models.ltx2_text_encoder.LTX2TextEncoder\",\n        \"state_dict_converter\": \"diffsynth.utils.state_dict_converters.ltx2_text_encoder.LTX2TextEncoderStateDictConverter\",\n    },\n    {\n        # Example: ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-19b-dev.safetensors\")\n        \"model_hash\": \"c79c458c6e99e0e14d47e676761732d2\",\n        \"model_name\": \"ltx2_latent_upsampler\",\n        \"model_class\": \"diffsynth.models.ltx2_upsampler.LTX2LatentUpsampler\",\n    },\n    {\n        # Example: ModelConfig(model_id=\"Lightricks/LTX-2.3\", origin_file_pattern=\"ltx-2.3-22b-dev.safetensors\")\n        \"model_hash\": \"f3a83ecf3995dcc4fae2d27e08ad5767\",\n        \"model_name\": \"ltx2_dit\",\n        \"model_class\": \"diffsynth.models.ltx2_dit.LTXModel\",\n        \"extra_kwargs\": {\"apply_gated_attention\": True, \"cross_attention_adaln\": True, \"caption_channels\": None},\n        \"state_dict_converter\": \"diffsynth.utils.state_dict_converters.ltx2_dit.LTXModelStateDictConverter\",\n    },\n    {\n        # Example: ModelConfig(model_id=\"Lightricks/LTX-2.3\", origin_file_pattern=\"ltx-2.3-22b-dev.safetensors\")\n        \"model_hash\": 
\"f3a83ecf3995dcc4fae2d27e08ad5767\",\n        \"model_name\": \"ltx2_video_vae_encoder\",\n        \"model_class\": \"diffsynth.models.ltx2_video_vae.LTX2VideoEncoder\",\n        \"extra_kwargs\": {\"encoder_version\": \"ltx-2.3\"},\n        \"state_dict_converter\": \"diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoEncoderStateDictConverter\",\n    },\n    {\n        # Example: ModelConfig(model_id=\"Lightricks/LTX-2.3\", origin_file_pattern=\"ltx-2.3-22b-dev.safetensors\")\n        \"model_hash\": \"f3a83ecf3995dcc4fae2d27e08ad5767\",\n        \"model_name\": \"ltx2_video_vae_decoder\",\n        \"model_class\": \"diffsynth.models.ltx2_video_vae.LTX2VideoDecoder\",\n        \"extra_kwargs\": {\"decoder_version\": \"ltx-2.3\"},\n        \"state_dict_converter\": \"diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoDecoderStateDictConverter\",\n    },\n    {\n        # Example: ModelConfig(model_id=\"Lightricks/LTX-2.3\", origin_file_pattern=\"ltx-2.3-22b-dev.safetensors\")\n        \"model_hash\": \"f3a83ecf3995dcc4fae2d27e08ad5767\",\n        \"model_name\": \"ltx2_audio_vae_decoder\",\n        \"model_class\": \"diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder\",\n        \"state_dict_converter\": \"diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioDecoderStateDictConverter\",\n    },\n    {\n        # Example: ModelConfig(model_id=\"Lightricks/LTX-2.3\", origin_file_pattern=\"ltx-2.3-22b-dev.safetensors\")\n        \"model_hash\": \"f3a83ecf3995dcc4fae2d27e08ad5767\",\n        \"model_name\": \"ltx2_audio_vocoder\",\n        \"model_class\": \"diffsynth.models.ltx2_audio_vae.LTX2VocoderWithBWE\",\n        \"state_dict_converter\": \"diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2VocoderStateDictConverter\",\n    },\n    {\n        # Example: ModelConfig(model_id=\"Lightricks/LTX-2.3\", origin_file_pattern=\"ltx-2.3-22b-dev.safetensors\")\n        \"model_hash\": \"f3a83ecf3995dcc4fae2d27e08ad5767\",\n        
\"model_name\": \"ltx2_audio_vae_encoder\",\n        \"model_class\": \"diffsynth.models.ltx2_audio_vae.LTX2AudioEncoder\",\n        \"state_dict_converter\": \"diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioEncoderStateDictConverter\",\n    },\n    {\n        # Example: ModelConfig(model_id=\"Lightricks/LTX-2.3\", origin_file_pattern=\"ltx-2.3-22b-dev.safetensors\")\n        \"model_hash\": \"f3a83ecf3995dcc4fae2d27e08ad5767\",\n        \"model_name\": \"ltx2_text_encoder_post_modules\",\n        \"model_class\": \"diffsynth.models.ltx2_text_encoder.LTX2TextEncoderPostModules\",\n        \"extra_kwargs\": {\"separated_audio_video\": True, \"embedding_dim_gemma\": 3840, \"num_layers_gemma\": 49, \"video_attention_heads\": 32, \"video_attention_head_dim\": 128, \"audio_attention_heads\": 32, \"audio_attention_head_dim\": 64, \"num_connector_layers\": 8, \"apply_gated_attention\": True},\n        \"state_dict_converter\": \"diffsynth.utils.state_dict_converters.ltx2_text_encoder.LTX2TextEncoderPostModulesStateDictConverter\",\n    },\n    {\n        # Example: ModelConfig(model_id=\"Lightricks/LTX-2.3\", origin_file_pattern=\"ltx-2.3-spatial-upscaler-x2-1.0.safetensors\")\n        \"model_hash\": \"aed408774d694a2452f69936c32febb5\",\n        \"model_name\": \"ltx2_latent_upsampler\",\n        \"model_class\": \"diffsynth.models.ltx2_upsampler.LTX2LatentUpsampler\",\n        \"extra_kwargs\": {\"rational_resampler\": False},\n    },\n    {\n        # Example: ModelConfig(model_id=\"DiffSynth-Studio/LTX-2.3-Repackage\", origin_file_pattern=\"transformer.safetensors\")\n        \"model_hash\": \"1c55afad76ed33c112a2978550b524d1\",\n        \"model_name\": \"ltx2_dit\",\n        \"model_class\": \"diffsynth.models.ltx2_dit.LTXModel\",\n        \"extra_kwargs\": {\"apply_gated_attention\": True, \"cross_attention_adaln\": True, \"caption_channels\": None},\n        \"state_dict_converter\": 
\"diffsynth.utils.state_dict_converters.ltx2_dit.LTXModelStateDictConverter\",\n    },\n    {\n        # Example: ModelConfig(model_id=\"DiffSynth-Studio/LTX-2.3-Repackage\", origin_file_pattern=\"video_vae_encoder.safetensors\")\n        \"model_hash\": \"eecdc07c2ec30863b8a2b8b2134036cf\",\n        \"model_name\": \"ltx2_video_vae_encoder\",\n        \"model_class\": \"diffsynth.models.ltx2_video_vae.LTX2VideoEncoder\",\n        \"extra_kwargs\": {\"encoder_version\": \"ltx-2.3\"},\n        \"state_dict_converter\": \"diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoEncoderStateDictConverter\",\n    },\n    {\n        # Example: ModelConfig(model_id=\"DiffSynth-Studio/LTX-2.3-Repackage\", origin_file_pattern=\"video_vae_decoder.safetensors\")\n        \"model_hash\": \"deda2f542e17ee25bc8c38fd605316ea\",\n        \"model_name\": \"ltx2_video_vae_decoder\",\n        \"model_class\": \"diffsynth.models.ltx2_video_vae.LTX2VideoDecoder\",\n        \"extra_kwargs\": {\"decoder_version\": \"ltx-2.3\"},\n        \"state_dict_converter\": \"diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoDecoderStateDictConverter\",\n    },\n    {\n        # Example: ModelConfig(model_id=\"DiffSynth-Studio/LTX-2.3-Repackage\", origin_file_pattern=\"audio_vae_decoder.safetensors\")\n        \"model_hash\": \"7d7823dde8f1ea0b50fb07ac329dd4cb\",\n        \"model_name\": \"ltx2_audio_vae_decoder\",\n        \"model_class\": \"diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder\",\n        \"state_dict_converter\": \"diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioDecoderStateDictConverter\",\n    },\n    {\n        # Example: ModelConfig(model_id=\"DiffSynth-Studio/LTX-2.3-Repackage\", origin_file_pattern=\"audio_vae_encoder.safetensors\")\n        \"model_hash\": \"29338f3b95e7e312a3460a482e4f4554\",\n        \"model_name\": \"ltx2_audio_vae_encoder\",\n        \"model_class\": \"diffsynth.models.ltx2_audio_vae.LTX2AudioEncoder\",\n        
\"state_dict_converter\": \"diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioEncoderStateDictConverter\",\n    },\n    {\n        # Example: ModelConfig(model_id=\"DiffSynth-Studio/LTX-2.3-Repackage\", origin_file_pattern=\"audio_vocoder.safetensors\")\n        \"model_hash\": \"cd436c99e69ec5c80f050f0944f02a15\",\n        \"model_name\": \"ltx2_audio_vocoder\",\n        \"model_class\": \"diffsynth.models.ltx2_audio_vae.LTX2VocoderWithBWE\",\n        \"state_dict_converter\": \"diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2VocoderStateDictConverter\",\n    },\n    {\n        # Example: ModelConfig(model_id=\"DiffSynth-Studio/LTX-2.3-Repackage\", origin_file_pattern=\"text_encoder_post_modules.safetensors\")\n        \"model_hash\": \"05da2aab1c4b061f72c426311c165a43\",\n        \"model_name\": \"ltx2_text_encoder_post_modules\",\n        \"model_class\": \"diffsynth.models.ltx2_text_encoder.LTX2TextEncoderPostModules\",\n        \"extra_kwargs\": {\"separated_audio_video\": True, \"embedding_dim_gemma\": 3840, \"num_layers_gemma\": 49, \"video_attention_heads\": 32, \"video_attention_head_dim\": 128, \"audio_attention_heads\": 32, \"audio_attention_head_dim\": 64, \"num_connector_layers\": 8, \"apply_gated_attention\": True},\n        \"state_dict_converter\": \"diffsynth.utils.state_dict_converters.ltx2_text_encoder.LTX2TextEncoderPostModulesStateDictConverter\",\n    },\n]\nanima_series = [\n    {\n        # Example: ModelConfig(model_id=\"circlestone-labs/Anima\", origin_file_pattern=\"split_files/vae/qwen_image_vae.safetensors\")\n        \"model_hash\": \"a9995952c2d8e63cf82e115005eb61b9\",\n        \"model_name\": \"z_image_text_encoder\",\n        \"model_class\": \"diffsynth.models.z_image_text_encoder.ZImageTextEncoder\",\n        \"extra_kwargs\": {\"model_size\": \"0.6B\"},\n    },\n    {\n        # Example: ModelConfig(model_id=\"circlestone-labs/Anima\", 
origin_file_pattern=\"split_files/diffusion_models/anima-preview.safetensors\")\n        \"model_hash\": \"417673936471e79e31ed4d186d7a3f4a\",\n        \"model_name\": \"anima_dit\",\n        \"model_class\": \"diffsynth.models.anima_dit.AnimaDiT\",\n        \"state_dict_converter\": \"diffsynth.utils.state_dict_converters.anima_dit.AnimaDiTStateDictConverter\",\n    }\n]\n\nmova_series = [\n    # Example: ModelConfig(model_id=\"openmoss/MOVA-720p\", origin_file_pattern=\"audio_dit/diffusion_pytorch_model.safetensors\")\n    {\n        \"model_hash\": \"8c57e12790e2c45a64817e0ce28cde2f\",\n        \"model_name\": \"mova_audio_dit\",\n        \"model_class\": \"diffsynth.models.mova_audio_dit.MovaAudioDit\",\n        \"extra_kwargs\": {'has_image_input': False, 'patch_size': [1], 'in_dim': 128, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 128, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06}\n    },\n    # Example: ModelConfig(model_id=\"openmoss/MOVA-720p\", origin_file_pattern=\"audio_vae/diffusion_pytorch_model.safetensors\")\n    {\n        \"model_hash\": \"418517fb2b4e919d2cac8f314fcf82ac\",\n        \"model_name\": \"mova_audio_vae\",\n        \"model_class\": \"diffsynth.models.mova_audio_vae.DacVAE\",\n    },\n    # Example: ModelConfig(model_id=\"openmoss/MOVA-720p\", origin_file_pattern=\"dual_tower_bridge/diffusion_pytorch_model.safetensors\")\n    {\n        \"model_hash\": \"d1139dbbc8b4ab53cf4b4243d57bbceb\",\n        \"model_name\": \"mova_dual_tower_bridge\",\n        \"model_class\": \"diffsynth.models.mova_dual_tower_bridge.DualTowerConditionalBridge\",\n    },\n]\nMODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series + ltx2_series + anima_series + mova_series\n"
  },
  {
    "path": "diffsynth/configs/vram_management_module_maps.py",
    "content": "flux_general_vram_config = {\n    \"torch.nn.Linear\": \"diffsynth.core.vram.layers.AutoWrappedLinear\",\n    \"torch.nn.Embedding\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n    \"torch.nn.LayerNorm\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n    \"torch.nn.Conv2d\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n    \"torch.nn.GroupNorm\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n    \"diffsynth.models.general_modules.RMSNorm\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n    \"diffsynth.models.flux_lora_encoder.LoRALayerBlock\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n    \"diffsynth.models.flux_lora_patcher.LoraMerger\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n}\n\nVRAM_MANAGEMENT_MODULE_MAPS = {\n    \"diffsynth.models.qwen_image_dit.QwenImageDiT\": {\n        \"diffsynth.models.qwen_image_dit.RMSNorm\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"torch.nn.Linear\": \"diffsynth.core.vram.layers.AutoWrappedLinear\",\n        \"torch.nn.Embedding\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n    },\n    \"diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder\": {\n        \"torch.nn.Linear\": \"diffsynth.core.vram.layers.AutoWrappedLinear\",\n        \"torch.nn.Embedding\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLRotaryEmbedding\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2RMSNorm\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VisionPatchEmbed\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VisionRotaryEmbedding\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n    },\n    \"diffsynth.models.qwen_image_vae.QwenImageVAE\": {\n        
\"torch.nn.Linear\": \"diffsynth.core.vram.layers.AutoWrappedLinear\",\n        \"torch.nn.Conv3d\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"torch.nn.Conv2d\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"diffsynth.models.qwen_image_vae.QwenImageRMS_norm\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n    },\n    \"diffsynth.models.qwen_image_controlnet.BlockWiseControlBlock\": {\n        \"diffsynth.models.qwen_image_dit.RMSNorm\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"torch.nn.Linear\": \"diffsynth.core.vram.layers.AutoWrappedLinear\",\n    },\n    \"diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder\": {\n        \"transformers.models.siglip.modeling_siglip.SiglipVisionEmbeddings\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"transformers.models.siglip.modeling_siglip.SiglipMultiheadAttentionPoolingHead\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"torch.nn.Conv2d\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"torch.nn.Embedding\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"torch.nn.LayerNorm\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"torch.nn.Linear\": \"diffsynth.core.vram.layers.AutoWrappedLinear\",\n    },\n    \"diffsynth.models.dinov3_image_encoder.DINOv3ImageEncoder\": {\n        \"transformers.models.dinov3_vit.modeling_dinov3_vit.DINOv3ViTLayerScale\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"transformers.models.dinov3_vit.modeling_dinov3_vit.DINOv3ViTRopePositionEmbedding\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"transformers.models.dinov3_vit.modeling_dinov3_vit.DINOv3ViTEmbeddings\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"torch.nn.Conv2d\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"torch.nn.LayerNorm\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"torch.nn.Linear\": 
\"diffsynth.core.vram.layers.AutoWrappedLinear\",\n    },\n    \"diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel\": {\n        \"torch.nn.Linear\": \"diffsynth.core.vram.layers.AutoWrappedLinear\",\n    },\n    \"diffsynth.models.wan_video_animate_adapter.WanAnimateAdapter\": {\n        \"diffsynth.models.wan_video_animate_adapter.FaceEncoder\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"diffsynth.models.wan_video_animate_adapter.EqualLinear\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"diffsynth.models.wan_video_animate_adapter.ConvLayer\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"diffsynth.models.wan_video_animate_adapter.FusedLeakyReLU\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"diffsynth.models.wan_video_animate_adapter.RMSNorm\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"torch.nn.Linear\": \"diffsynth.core.vram.layers.AutoWrappedLinear\",\n        \"torch.nn.LayerNorm\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"torch.nn.Conv1d\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"torch.nn.Conv2d\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"torch.nn.Conv3d\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n    },\n    \"diffsynth.models.wan_video_dit_s2v.WanS2VModel\": {\n        \"diffsynth.models.wan_video_dit.Head\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"diffsynth.models.wan_video_dit_s2v.WanS2VDiTBlock\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"diffsynth.models.wan_video_dit_s2v.CausalAudioEncoder\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"torch.nn.Embedding\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"torch.nn.Linear\": \"diffsynth.core.vram.layers.AutoWrappedLinear\",\n        \"torch.nn.Conv3d\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"torch.nn.LayerNorm\": 
\"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"diffsynth.models.wan_video_dit.RMSNorm\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"torch.nn.Conv2d\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n    },\n    \"diffsynth.models.wan_video_dit.WanModel\": {\n        \"diffsynth.models.wan_video_dit.MLP\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"diffsynth.models.wan_video_dit.DiTBlock\": \"diffsynth.core.vram.layers.AutoWrappedNonRecurseModule\",\n        \"diffsynth.models.wan_video_dit.Head\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"torch.nn.Linear\": \"diffsynth.core.vram.layers.AutoWrappedLinear\",\n        \"torch.nn.Conv3d\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"torch.nn.LayerNorm\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"diffsynth.models.wan_video_dit.RMSNorm\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"torch.nn.Conv2d\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n    },\n    \"diffsynth.models.wan_video_image_encoder.WanImageEncoder\": {\n        \"diffsynth.models.wan_video_image_encoder.VisionTransformer\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"torch.nn.Linear\": \"diffsynth.core.vram.layers.AutoWrappedLinear\",\n        \"torch.nn.Conv2d\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"torch.nn.LayerNorm\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n    },\n    \"diffsynth.models.wan_video_mot.MotWanModel\": {\n        \"diffsynth.models.wan_video_mot.MotWanAttentionBlock\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"torch.nn.Conv3d\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"torch.nn.Linear\": \"diffsynth.core.vram.layers.AutoWrappedLinear\",\n        \"torch.nn.LayerNorm\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n    },\n    \"diffsynth.models.wan_video_motion_controller.WanMotionControllerModel\": {\n        
\"torch.nn.Linear\": \"diffsynth.core.vram.layers.AutoWrappedLinear\",\n    },\n    \"diffsynth.models.wan_video_text_encoder.WanTextEncoder\": {\n        \"torch.nn.Linear\": \"diffsynth.core.vram.layers.AutoWrappedLinear\",\n        \"torch.nn.Embedding\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"diffsynth.models.wan_video_text_encoder.T5RelativeEmbedding\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"diffsynth.models.wan_video_text_encoder.T5LayerNorm\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n    },\n    \"diffsynth.models.wan_video_vace.VaceWanModel\": {\n        \"diffsynth.models.wan_video_dit.DiTBlock\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"torch.nn.Linear\": \"diffsynth.core.vram.layers.AutoWrappedLinear\",\n        \"torch.nn.Conv3d\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"torch.nn.LayerNorm\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"diffsynth.models.wan_video_dit.RMSNorm\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n    },\n    \"diffsynth.models.wan_video_vae.WanVideoVAE\": {\n        \"torch.nn.Linear\": \"diffsynth.core.vram.layers.AutoWrappedLinear\",\n        \"torch.nn.Conv2d\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"diffsynth.models.wan_video_vae.RMS_norm\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"diffsynth.models.wan_video_vae.CausalConv3d\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"diffsynth.models.wan_video_vae.Upsample\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"torch.nn.SiLU\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"torch.nn.Dropout\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n    },\n    \"diffsynth.models.wan_video_vae.WanVideoVAE38\": {\n        \"torch.nn.Linear\": \"diffsynth.core.vram.layers.AutoWrappedLinear\",\n        \"torch.nn.Conv2d\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        
\"diffsynth.models.wan_video_vae.RMS_norm\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"diffsynth.models.wan_video_vae.CausalConv3d\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"diffsynth.models.wan_video_vae.Upsample\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"torch.nn.SiLU\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"torch.nn.Dropout\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n    },\n    \"diffsynth.models.wav2vec.WanS2VAudioEncoder\": {\n        \"torch.nn.Linear\": \"diffsynth.core.vram.layers.AutoWrappedLinear\",\n        \"torch.nn.LayerNorm\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"torch.nn.Conv1d\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n    },\n    \"diffsynth.models.longcat_video_dit.LongCatVideoTransformer3DModel\": {\n        \"torch.nn.Linear\": \"diffsynth.core.vram.layers.AutoWrappedLinear\",\n        \"torch.nn.Conv3d\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"diffsynth.models.longcat_video_dit.RMSNorm_FP32\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"diffsynth.models.longcat_video_dit.LayerNorm_FP32\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n    },\n    \"diffsynth.models.flux_dit.FluxDiT\": {\n        \"torch.nn.Linear\": \"diffsynth.core.vram.layers.AutoWrappedLinear\",\n        \"diffsynth.models.flux_dit.RMSNorm\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n    },\n    \"diffsynth.models.flux_text_encoder_clip.FluxTextEncoderClip\": flux_general_vram_config,\n    \"diffsynth.models.flux_vae.FluxVAEEncoder\": flux_general_vram_config,\n    \"diffsynth.models.flux_vae.FluxVAEDecoder\": flux_general_vram_config,\n    \"diffsynth.models.flux_controlnet.FluxControlNet\": flux_general_vram_config,\n    \"diffsynth.models.flux_infiniteyou.InfiniteYouImageProjector\": flux_general_vram_config,\n    \"diffsynth.models.flux_ipadapter.FluxIpAdapter\": flux_general_vram_config,\n  
  \"diffsynth.models.flux_lora_patcher.FluxLoraPatcher\": flux_general_vram_config,\n    \"diffsynth.models.step1x_connector.Qwen2Connector\": flux_general_vram_config,\n    \"diffsynth.models.flux_lora_encoder.FluxLoRAEncoder\": flux_general_vram_config,\n    \"diffsynth.models.flux_text_encoder_t5.FluxTextEncoderT5\": {\n        \"torch.nn.Linear\": \"diffsynth.core.vram.layers.AutoWrappedLinear\",\n        \"torch.nn.Embedding\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"transformers.models.t5.modeling_t5.T5LayerNorm\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"transformers.models.t5.modeling_t5.T5DenseActDense\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"transformers.models.t5.modeling_t5.T5DenseGatedActDense\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n    },\n    \"diffsynth.models.flux_ipadapter.SiglipVisionModelSO400M\": {\n        \"transformers.models.siglip.modeling_siglip.SiglipVisionEmbeddings\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"transformers.models.siglip.modeling_siglip.SiglipEncoder\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"transformers.models.siglip.modeling_siglip.SiglipMultiheadAttentionPoolingHead\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"torch.nn.MultiheadAttention\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"torch.nn.Linear\": \"diffsynth.core.vram.layers.AutoWrappedLinear\",\n        \"torch.nn.LayerNorm\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n    },\n    \"diffsynth.models.flux2_dit.Flux2DiT\": {\n        \"torch.nn.Linear\": \"diffsynth.core.vram.layers.AutoWrappedLinear\",\n        \"torch.nn.LayerNorm\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"torch.nn.RMSNorm\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n    },\n    \"diffsynth.models.flux2_text_encoder.Flux2TextEncoder\": {\n        \"torch.nn.Linear\": 
\"diffsynth.core.vram.layers.AutoWrappedLinear\",\n        \"torch.nn.Conv2d\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"torch.nn.Embedding\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"transformers.models.mistral.modeling_mistral.MistralRMSNorm\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n    },\n    \"diffsynth.models.flux2_vae.Flux2VAE\": {\n        \"torch.nn.Linear\": \"diffsynth.core.vram.layers.AutoWrappedLinear\",\n        \"torch.nn.Conv2d\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"torch.nn.GroupNorm\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n    },\n    \"diffsynth.models.z_image_text_encoder.ZImageTextEncoder\": {\n        \"torch.nn.Linear\": \"diffsynth.core.vram.layers.AutoWrappedLinear\",\n        \"transformers.models.qwen3.modeling_qwen3.Qwen3RMSNorm\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"torch.nn.Embedding\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n    },\n    \"diffsynth.models.z_image_dit.ZImageDiT\": {\n        \"torch.nn.Linear\": \"diffsynth.core.vram.layers.AutoWrappedLinear\",\n        \"diffsynth.models.z_image_dit.RMSNorm\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n    },\n    \"diffsynth.models.z_image_controlnet.ZImageControlNet\": {\n        \"torch.nn.Linear\": \"diffsynth.core.vram.layers.AutoWrappedLinear\",\n        \"diffsynth.models.z_image_dit.RMSNorm\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n    },\n    \"diffsynth.models.z_image_image2lora.ZImageImage2LoRAModel\": {\n        \"torch.nn.Linear\": \"diffsynth.core.vram.layers.AutoWrappedLinear\",\n    },\n    \"diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder428M\": {\n        \"transformers.models.siglip2.modeling_siglip2.Siglip2VisionEmbeddings\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"transformers.models.siglip2.modeling_siglip2.Siglip2MultiheadAttentionPoolingHead\": 
\"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"torch.nn.Conv2d\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"torch.nn.Embedding\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"torch.nn.LayerNorm\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"torch.nn.Linear\": \"diffsynth.core.vram.layers.AutoWrappedLinear\",\n    },\n    \"diffsynth.models.ltx2_dit.LTXModel\": {\n        \"torch.nn.Linear\": \"diffsynth.core.vram.layers.AutoWrappedLinear\",\n        \"torch.nn.RMSNorm\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n    },\n    \"diffsynth.models.ltx2_upsampler.LTX2LatentUpsampler\": {\n        \"torch.nn.Conv2d\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"torch.nn.Conv3d\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"torch.nn.GroupNorm\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n    },\n    \"diffsynth.models.ltx2_video_vae.LTX2VideoEncoder\": {\n        \"torch.nn.Conv3d\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n    },\n    \"diffsynth.models.ltx2_video_vae.LTX2VideoDecoder\": {\n        \"torch.nn.Conv3d\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n    },\n    \"diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder\": {\n        \"torch.nn.Conv2d\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n    },\n    \"diffsynth.models.ltx2_audio_vae.LTX2Vocoder\": {\n        \"torch.nn.Conv1d\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"torch.nn.ConvTranspose1d\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n    },\n    \"diffsynth.models.ltx2_text_encoder.LTX2TextEncoderPostModules\": {\n        \"torch.nn.Linear\": \"diffsynth.core.vram.layers.AutoWrappedLinear\",\n        \"torch.nn.RMSNorm\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"diffsynth.models.ltx2_text_encoder.Embeddings1DConnector\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n    },\n    
\"diffsynth.models.ltx2_text_encoder.LTX2TextEncoder\": {\n        \"torch.nn.Linear\": \"diffsynth.core.vram.layers.AutoWrappedLinear\",\n        \"transformers.models.gemma3.modeling_gemma3.Gemma3MultiModalProjector\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"transformers.models.gemma3.modeling_gemma3.Gemma3RMSNorm\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"transformers.models.gemma3.modeling_gemma3.Gemma3TextScaledWordEmbedding\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n    },\n    \"diffsynth.models.anima_dit.AnimaDiT\": {\n        \"torch.nn.Linear\": \"diffsynth.core.vram.layers.AutoWrappedLinear\",\n        \"torch.nn.LayerNorm\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"torch.nn.RMSNorm\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"torch.nn.Embedding\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n    },\n    \"diffsynth.models.mova_audio_dit.MovaAudioDit\": {\n        \"diffsynth.models.wan_video_dit.DiTBlock\": \"diffsynth.core.vram.layers.AutoWrappedNonRecurseModule\",\n        \"diffsynth.models.wan_video_dit.Head\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"torch.nn.Linear\": \"diffsynth.core.vram.layers.AutoWrappedLinear\",\n        \"torch.nn.Conv1d\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"torch.nn.LayerNorm\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"diffsynth.models.wan_video_dit.RMSNorm\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n    },\n    \"diffsynth.models.mova_dual_tower_bridge.DualTowerConditionalBridge\": {\n        \"torch.nn.Linear\": \"diffsynth.core.vram.layers.AutoWrappedLinear\",\n        \"torch.nn.LayerNorm\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"diffsynth.models.wan_video_dit.RMSNorm\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n    },\n    \"diffsynth.models.mova_audio_vae.DacVAE\": {\n        
\"diffsynth.models.mova_audio_vae.Snake1d\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"torch.nn.Conv1d\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n        \"torch.nn.ConvTranspose1d\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n    },\n}\n\ndef QwenImageTextEncoder_Module_Map_Updater():\n    current = VRAM_MANAGEMENT_MODULE_MAPS[\"diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder\"]\n    from packaging import version\n    import transformers\n    if version.parse(transformers.__version__) >= version.parse(\"5.2.0\"):\n        # The Qwen2RMSNorm in transformers 5.2.0+ has been renamed to Qwen2_5_VLRMSNorm, so we need to update the module map accordingly\n        current.pop(\"transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2RMSNorm\", None)\n        current[\"transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLRMSNorm\"] = \"diffsynth.core.vram.layers.AutoWrappedModule\"\n    return current\n\nVERSION_CHECKER_MAPS = {\n    \"diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder\": QwenImageTextEncoder_Module_Map_Updater,\n}"
  },
  {
    "path": "diffsynth/core/__init__.py",
    "content": "from .attention import *\nfrom .data import *\nfrom .gradient import *\nfrom .loader import *\nfrom .vram import *\nfrom .device import *\n"
  },
  {
    "path": "diffsynth/core/attention/__init__.py",
    "content": "from .attention import attention_forward\n"
  },
  {
    "path": "diffsynth/core/attention/attention.py",
    "content": "import torch, os\nfrom einops import rearrange\n\n\ntry:\n    import flash_attn_interface\n    FLASH_ATTN_3_AVAILABLE = True\nexcept ModuleNotFoundError:\n    FLASH_ATTN_3_AVAILABLE = False\n\ntry:\n    import flash_attn\n    FLASH_ATTN_2_AVAILABLE = True\nexcept ModuleNotFoundError:\n    FLASH_ATTN_2_AVAILABLE = False\n\ntry:\n    from sageattention import sageattn\n    SAGE_ATTN_AVAILABLE = True\nexcept ModuleNotFoundError:\n    SAGE_ATTN_AVAILABLE = False\n\ntry:\n    import xformers.ops as xops\n    XFORMERS_AVAILABLE = True\nexcept ModuleNotFoundError:\n    XFORMERS_AVAILABLE = False\n\n\ndef initialize_attention_priority():\n    if os.environ.get('DIFFSYNTH_ATTENTION_IMPLEMENTATION') is not None:\n        return os.environ.get('DIFFSYNTH_ATTENTION_IMPLEMENTATION').lower()\n    elif FLASH_ATTN_3_AVAILABLE:\n        return \"flash_attention_3\"\n    elif FLASH_ATTN_2_AVAILABLE:\n        return \"flash_attention_2\"\n    elif SAGE_ATTN_AVAILABLE:\n        return \"sage_attention\"\n    elif XFORMERS_AVAILABLE:\n        return \"xformers\"\n    else:\n        return \"torch\"\n\n\nATTENTION_IMPLEMENTATION = initialize_attention_priority()\n\n\ndef rearrange_qkv(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern=\"b n s d\", k_pattern=\"b n s d\", v_pattern=\"b n s d\", required_in_pattern=\"b n s d\", dims=None):\n    dims = {} if dims is None else dims\n    if q_pattern != required_in_pattern:\n        q = rearrange(q, f\"{q_pattern} -> {required_in_pattern}\", **dims)\n    if k_pattern != required_in_pattern:\n        k = rearrange(k, f\"{k_pattern} -> {required_in_pattern}\", **dims)\n    if v_pattern != required_in_pattern:\n        v = rearrange(v, f\"{v_pattern} -> {required_in_pattern}\", **dims)\n    return q, k, v\n\n\ndef rearrange_out(out: torch.Tensor, out_pattern=\"b n s d\", required_out_pattern=\"b n s d\", dims=None):\n    dims = {} if dims is None else dims\n    if out_pattern != required_out_pattern:\n        out = 
rearrange(out, f\"{required_out_pattern} -> {out_pattern}\", **dims)\n    return out\n\n\ndef torch_sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern=\"b n s d\", k_pattern=\"b n s d\", v_pattern=\"b n s d\", out_pattern=\"b n s d\", dims=None, attn_mask=None, scale=None):\n    required_in_pattern, required_out_pattern= \"b n s d\", \"b n s d\"\n    q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)\n    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask, scale=scale)\n    out = rearrange_out(out, out_pattern, required_out_pattern, dims)\n    return out\n\n\ndef flash_attention_3(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern=\"b n s d\", k_pattern=\"b n s d\", v_pattern=\"b n s d\", out_pattern=\"b n s d\", dims=None, scale=None):\n    required_in_pattern, required_out_pattern= \"b s n d\", \"b s n d\"\n    q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)\n    out = flash_attn_interface.flash_attn_func(q, k, v, softmax_scale=scale)\n    if isinstance(out, tuple):\n        out = out[0]\n    out = rearrange_out(out, out_pattern, required_out_pattern, dims)\n    return out\n\n\ndef flash_attention_2(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern=\"b n s d\", k_pattern=\"b n s d\", v_pattern=\"b n s d\", out_pattern=\"b n s d\", dims=None, scale=None):\n    required_in_pattern, required_out_pattern= \"b s n d\", \"b s n d\"\n    q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)\n    out = flash_attn.flash_attn_func(q, k, v, softmax_scale=scale)\n    out = rearrange_out(out, out_pattern, required_out_pattern, dims)\n    return out\n\n\ndef sage_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern=\"b n s d\", k_pattern=\"b n s d\", v_pattern=\"b n s d\", out_pattern=\"b n s d\", dims=None, scale=None):\n    required_in_pattern, required_out_pattern= \"b n s 
d\", \"b n s d\"\n    q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)\n    out = sageattn(q, k, v, sm_scale=scale)\n    out = rearrange_out(out, out_pattern, required_out_pattern, dims)\n    return out\n\n\ndef xformers_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern=\"b n s d\", k_pattern=\"b n s d\", v_pattern=\"b n s d\", out_pattern=\"b n s d\", dims=None, scale=None):\n    required_in_pattern, required_out_pattern= \"b s n d\", \"b s n d\"\n    q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)\n    out = xops.memory_efficient_attention(q, k, v, scale=scale)\n    out = rearrange_out(out, out_pattern, required_out_pattern, dims)\n    return out\n\n\ndef attention_forward(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern=\"b n s d\", k_pattern=\"b n s d\", v_pattern=\"b n s d\", out_pattern=\"b n s d\", dims=None, attn_mask=None, scale=None, compatibility_mode=False):\n    if compatibility_mode or (attn_mask is not None):\n        return torch_sdpa(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, attn_mask=attn_mask, scale=scale)\n    else:\n        if ATTENTION_IMPLEMENTATION == \"flash_attention_3\":\n            return flash_attention_3(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)\n        elif ATTENTION_IMPLEMENTATION == \"flash_attention_2\":\n            return flash_attention_2(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)\n        elif ATTENTION_IMPLEMENTATION == \"sage_attention\":\n            return sage_attention(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)\n        elif ATTENTION_IMPLEMENTATION == \"xformers\":\n            return xformers_attention(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)\n        else:\n            return torch_sdpa(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, 
scale=scale)\n"
  },
  {
    "path": "diffsynth/core/data/__init__.py",
    "content": "from .unified_dataset import UnifiedDataset\n"
  },
  {
    "path": "diffsynth/core/data/operators.py",
    "content": "import math\nimport torch, torchvision, imageio, os\nimport imageio.v3 as iio\nfrom PIL import Image\nimport torchaudio\n\n\nclass DataProcessingPipeline:\n    def __init__(self, operators=None):\n        self.operators: list[DataProcessingOperator] = [] if operators is None else operators\n        \n    def __call__(self, data):\n        for operator in self.operators:\n            data = operator(data)\n        return data\n    \n    def __rshift__(self, pipe):\n        if isinstance(pipe, DataProcessingOperator):\n            pipe = DataProcessingPipeline([pipe])\n        return DataProcessingPipeline(self.operators + pipe.operators)\n\n\nclass DataProcessingOperator:\n    def __call__(self, data):\n        raise NotImplementedError(\"DataProcessingOperator cannot be called directly.\")\n    \n    def __rshift__(self, pipe):\n        if isinstance(pipe, DataProcessingOperator):\n            pipe = DataProcessingPipeline([pipe])\n        return DataProcessingPipeline([self]).__rshift__(pipe)\n\n\nclass DataProcessingOperatorRaw(DataProcessingOperator):\n    def __call__(self, data):\n        return data\n\n\nclass ToInt(DataProcessingOperator):\n    def __call__(self, data):\n        return int(data)\n\n\nclass ToFloat(DataProcessingOperator):\n    def __call__(self, data):\n        return float(data)\n\n\nclass ToStr(DataProcessingOperator):\n    def __init__(self, none_value=\"\"):\n        self.none_value = none_value\n    \n    def __call__(self, data):\n        if data is None: data = self.none_value\n        return str(data)\n\n\nclass LoadImage(DataProcessingOperator):\n    def __init__(self, convert_RGB=True, convert_RGBA=False):\n        self.convert_RGB = convert_RGB\n        self.convert_RGBA = convert_RGBA\n    \n    def __call__(self, data: str):\n        image = Image.open(data)\n        if self.convert_RGB: image = image.convert(\"RGB\")\n        if self.convert_RGBA: image = image.convert(\"RGBA\")\n        return image\n\n\nclass 
ImageCropAndResize(DataProcessingOperator):\n    def __init__(self, height=None, width=None, max_pixels=None, height_division_factor=1, width_division_factor=1):\n        self.height = height\n        self.width = width\n        self.max_pixels = max_pixels\n        self.height_division_factor = height_division_factor\n        self.width_division_factor = width_division_factor\n\n    def crop_and_resize(self, image, target_height, target_width):\n        width, height = image.size\n        scale = max(target_width / width, target_height / height)\n        image = torchvision.transforms.functional.resize(\n            image,\n            (round(height*scale), round(width*scale)),\n            interpolation=torchvision.transforms.InterpolationMode.BILINEAR\n        )\n        image = torchvision.transforms.functional.center_crop(image, (target_height, target_width))\n        return image\n    \n    def get_height_width(self, image):\n        if self.height is None or self.width is None:\n            width, height = image.size\n            if width * height > self.max_pixels:\n                scale = (width * height / self.max_pixels) ** 0.5\n                height, width = int(height / scale), int(width / scale)\n            height = height // self.height_division_factor * self.height_division_factor\n            width = width // self.width_division_factor * self.width_division_factor\n        else:\n            height, width = self.height, self.width\n        return height, width\n    \n    def __call__(self, data: Image.Image):\n        image = self.crop_and_resize(data, *self.get_height_width(data))\n        return image\n\n\nclass ToList(DataProcessingOperator):\n    def __call__(self, data):\n        return [data]\n    \n\nclass FrameSamplerByRateMixin:\n    def __init__(self, num_frames=81, time_division_factor=4, time_division_remainder=1, frame_rate=24, fix_frame_rate=False):\n        self.num_frames = num_frames\n        self.time_division_factor = 
time_division_factor\n        self.time_division_remainder = time_division_remainder\n        self.frame_rate = frame_rate\n        self.fix_frame_rate = fix_frame_rate\n\n    def get_reader(self, data: str):\n        return imageio.get_reader(data)\n\n    def get_available_num_frames(self, reader):\n        if not self.fix_frame_rate:\n            return reader.count_frames()\n        meta_data = reader.get_meta_data()\n        total_original_frames = int(reader.count_frames())\n        duration = meta_data[\"duration\"] if \"duration\" in meta_data else total_original_frames / meta_data['fps']\n        total_available_frames = math.floor(duration * self.frame_rate)\n        return int(total_available_frames)\n\n    def get_num_frames(self, reader):\n        num_frames = self.num_frames\n        total_frames = self.get_available_num_frames(reader)\n        if int(total_frames) < num_frames:\n            num_frames = total_frames\n            while num_frames > 1 and num_frames % self.time_division_factor != self.time_division_remainder:\n                num_frames -= 1\n        return num_frames\n\n    def map_single_frame_id(self, new_sequence_id: int, raw_frame_rate: float, total_raw_frames: int) -> int:\n        if not self.fix_frame_rate:\n            return new_sequence_id\n        target_time_in_seconds = new_sequence_id / self.frame_rate\n        raw_frame_index_float = target_time_in_seconds * raw_frame_rate\n        frame_id = int(round(raw_frame_index_float))        \n        frame_id = min(frame_id, total_raw_frames - 1)\n        return frame_id\n\n\nclass LoadVideo(DataProcessingOperator, FrameSamplerByRateMixin):\n    def __init__(self, num_frames=81, time_division_factor=4, time_division_remainder=1, frame_processor=lambda x: x, frame_rate=24, fix_frame_rate=False):\n        FrameSamplerByRateMixin.__init__(self, num_frames, time_division_factor, time_division_remainder, frame_rate, fix_frame_rate)\n        # frame_processor is build in the video 
loader for high efficiency.\n        self.frame_processor = frame_processor\n\n    def __call__(self, data: str):\n        reader = self.get_reader(data)\n        raw_frame_rate = reader.get_meta_data()['fps']\n        num_frames = self.get_num_frames(reader)\n        total_raw_frames = reader.count_frames()\n        frames = []\n        for frame_id in range(num_frames):\n            frame_id = self.map_single_frame_id(frame_id, raw_frame_rate, total_raw_frames)\n            frame = reader.get_data(frame_id)\n            frame = Image.fromarray(frame)\n            frame = self.frame_processor(frame)\n            frames.append(frame)\n        reader.close()\n        return frames\n\n\nclass SequencialProcess(DataProcessingOperator):\n    def __init__(self, operator=lambda x: x):\n        self.operator = operator\n        \n    def __call__(self, data):\n        return [self.operator(i) for i in data]\n\n\nclass LoadGIF(DataProcessingOperator):\n    def __init__(self, num_frames=81, time_division_factor=4, time_division_remainder=1, frame_processor=lambda x: x):\n        self.num_frames = num_frames\n        self.time_division_factor = time_division_factor\n        self.time_division_remainder = time_division_remainder\n        # frame_processor is build in the video loader for high efficiency.\n        self.frame_processor = frame_processor\n\n    def get_num_frames(self, path):\n        num_frames = self.num_frames\n        images = iio.imread(path, mode=\"RGB\")\n        if len(images) < num_frames:\n            num_frames = len(images)\n            while num_frames > 1 and num_frames % self.time_division_factor != self.time_division_remainder:\n                num_frames -= 1\n        return num_frames\n        \n    def __call__(self, data: str):\n        num_frames = self.get_num_frames(data)\n        frames = []\n        images = iio.imread(data, mode=\"RGB\")\n        for img in images:\n            frame = Image.fromarray(img)\n            frame = 
self.frame_processor(frame)\n            frames.append(frame)\n            if len(frames) >= num_frames:\n                break\n        return frames\n\n\nclass RouteByExtensionName(DataProcessingOperator):\n    def __init__(self, operator_map):\n        self.operator_map = operator_map\n        \n    def __call__(self, data: str):\n        file_ext_name = data.split(\".\")[-1].lower()\n        for ext_names, operator in self.operator_map:\n            if ext_names is None or file_ext_name in ext_names:\n                return operator(data)\n        raise ValueError(f\"Unsupported file: {data}\")\n\n\nclass RouteByType(DataProcessingOperator):\n    def __init__(self, operator_map):\n        self.operator_map = operator_map\n        \n    def __call__(self, data):\n        for dtype, operator in self.operator_map:\n            if dtype is None or isinstance(data, dtype):\n                return operator(data)\n        raise ValueError(f\"Unsupported data: {data}\")\n\n\nclass LoadTorchPickle(DataProcessingOperator):\n    def __init__(self, map_location=\"cpu\"):\n        self.map_location = map_location\n        \n    def __call__(self, data):\n        return torch.load(data, map_location=self.map_location, weights_only=False)\n\n\nclass ToAbsolutePath(DataProcessingOperator):\n    def __init__(self, base_path=\"\"):\n        self.base_path = base_path\n        \n    def __call__(self, data):\n        return os.path.join(self.base_path, data)\n\n\nclass LoadAudio(DataProcessingOperator):\n    def __init__(self, sr=16000):\n        self.sr = sr\n    def __call__(self, data: str):\n        import librosa\n        input_audio, sample_rate = librosa.load(data, sr=self.sr)\n        return input_audio\n\n\nclass LoadAudioWithTorchaudio(DataProcessingOperator, FrameSamplerByRateMixin):\n\n    def __init__(self, num_frames=121, time_division_factor=8, time_division_remainder=1, frame_rate=24, fix_frame_rate=True):\n        FrameSamplerByRateMixin.__init__(self, 
num_frames, time_division_factor, time_division_remainder, frame_rate, fix_frame_rate)\n\n    def __call__(self, data: str):\n        reader = self.get_reader(data)\n        num_frames = self.get_num_frames(reader)\n        duration = num_frames / self.frame_rate\n        waveform, sample_rate = torchaudio.load(data)\n        target_samples = int(duration * sample_rate)\n        current_samples = waveform.shape[-1]\n        if current_samples > target_samples:\n            waveform = waveform[..., :target_samples]\n        elif current_samples < target_samples:\n            padding = target_samples - current_samples\n            waveform = torch.nn.functional.pad(waveform, (0, padding))\n        return waveform, sample_rate\n"
  },
  {
    "path": "diffsynth/core/data/unified_dataset.py",
    "content": "from .operators import *\nimport torch, json, pandas\n\n\nclass UnifiedDataset(torch.utils.data.Dataset):\n    def __init__(\n        self,\n        base_path=None, metadata_path=None,\n        repeat=1,\n        data_file_keys=tuple(),\n        main_data_operator=lambda x: x,\n        special_operator_map=None,\n        max_data_items=None,\n    ):\n        self.base_path = base_path\n        self.metadata_path = metadata_path\n        self.repeat = repeat\n        self.data_file_keys = data_file_keys\n        self.main_data_operator = main_data_operator\n        self.cached_data_operator = LoadTorchPickle()\n        self.special_operator_map = {} if special_operator_map is None else special_operator_map\n        self.max_data_items = max_data_items\n        self.data = []\n        self.cached_data = []\n        self.load_from_cache = metadata_path is None\n        self.load_metadata(metadata_path)\n    \n    @staticmethod\n    def default_image_operator(\n        base_path=\"\",\n        max_pixels=1920*1080, height=None, width=None,\n        height_division_factor=16, width_division_factor=16,\n    ):\n        return RouteByType(operator_map=[\n            (str, ToAbsolutePath(base_path) >> LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor)),\n            (list, SequencialProcess(ToAbsolutePath(base_path) >> LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor))),\n        ])\n    \n    @staticmethod\n    def default_video_operator(\n        base_path=\"\",\n        max_pixels=1920*1080, height=None, width=None,\n        height_division_factor=16, width_division_factor=16,\n        num_frames=81, time_division_factor=4, time_division_remainder=1,\n        frame_rate=24, fix_frame_rate=False,\n    ):\n        return RouteByType(operator_map=[\n            (str, ToAbsolutePath(base_path) >> RouteByExtensionName(operator_map=[\n             
   ((\"jpg\", \"jpeg\", \"png\", \"webp\"), LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor) >> ToList()),\n                ((\"gif\",), LoadGIF(\n                    num_frames, time_division_factor, time_division_remainder,\n                    frame_processor=ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor),\n                )),\n                ((\"mp4\", \"avi\", \"mov\", \"wmv\", \"mkv\", \"flv\", \"webm\"), LoadVideo(\n                    num_frames, time_division_factor, time_division_remainder,\n                    frame_processor=ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor),\n                    frame_rate=frame_rate, fix_frame_rate=fix_frame_rate,\n                )),\n            ])),\n        ])\n        \n    def search_for_cached_data_files(self, path):\n        for file_name in os.listdir(path):\n            subpath = os.path.join(path, file_name)\n            if os.path.isdir(subpath):\n                self.search_for_cached_data_files(subpath)\n            elif subpath.endswith(\".pth\"):\n                self.cached_data.append(subpath)\n    \n    def load_metadata(self, metadata_path):\n        if metadata_path is None:\n            print(\"No metadata_path. 
Searching for cached data files.\")\n            self.search_for_cached_data_files(self.base_path)\n            print(f\"{len(self.cached_data)} cached data files found.\")\n        elif metadata_path.endswith(\".json\"):\n            with open(metadata_path, \"r\") as f:\n                metadata = json.load(f)\n            self.data = metadata\n        elif metadata_path.endswith(\".jsonl\"):\n            metadata = []\n            with open(metadata_path, 'r') as f:\n                for line in f:\n                    metadata.append(json.loads(line.strip()))\n            self.data = metadata\n        else:\n            metadata = pandas.read_csv(metadata_path)\n            self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))]\n\n    def __getitem__(self, data_id):\n        if self.load_from_cache:\n            data = self.cached_data[data_id % len(self.cached_data)]\n            data = self.cached_data_operator(data)\n        else:\n            data = self.data[data_id % len(self.data)].copy()\n            for key in self.data_file_keys:\n                if key in data:\n                    if key in self.special_operator_map:\n                        data[key] = self.special_operator_map[key](data[key])\n                    elif key in self.data_file_keys:\n                        data[key] = self.main_data_operator(data[key])\n        return data\n\n    def __len__(self):\n        if self.max_data_items is not None:\n            return self.max_data_items\n        elif self.load_from_cache:\n            return len(self.cached_data) * self.repeat\n        else:\n            return len(self.data) * self.repeat\n        \n    def check_data_equal(self, data1, data2):\n        # Debug only\n        if len(data1) != len(data2):\n            return False\n        for k in data1:\n            if data1[k] != data2[k]:\n                return False\n        return True\n"
  },
  {
    "path": "diffsynth/core/device/__init__.py",
    "content": "from .npu_compatible_device import parse_device_type, parse_nccl_backend, get_available_device_type, get_device_name\nfrom .npu_compatible_device import IS_NPU_AVAILABLE, IS_CUDA_AVAILABLE\n"
  },
  {
    "path": "diffsynth/core/device/npu_compatible_device.py",
    "content": "import importlib\nimport torch\nfrom typing import Any\n\n\ndef is_torch_npu_available():\n    return importlib.util.find_spec(\"torch_npu\") is not None\n\n\nIS_CUDA_AVAILABLE = torch.cuda.is_available()\nIS_NPU_AVAILABLE = is_torch_npu_available() and torch.npu.is_available()\n\nif IS_NPU_AVAILABLE:\n    import torch_npu\n\n    torch.npu.config.allow_internal_format = False\n\n\ndef get_device_type() -> str:\n    \"\"\"Get device type based on current machine, currently only support CPU, CUDA, NPU.\"\"\"\n    if IS_CUDA_AVAILABLE:\n        device = \"cuda\"\n    elif IS_NPU_AVAILABLE:\n        device = \"npu\"\n    else:\n        device = \"cpu\"\n\n    return device\n\n\ndef get_torch_device() -> Any:\n    \"\"\"Get torch attribute based on device type, e.g. torch.cuda or torch.npu\"\"\"\n    device_name = get_device_type()\n\n    try:\n        return getattr(torch, device_name)\n    except AttributeError:\n        print(f\"Device namespace '{device_name}' not found in torch, try to load 'torch.cuda'.\")\n        return torch.cuda\n\n\ndef get_device_id() -> int:\n    \"\"\"Get current device id based on device type.\"\"\"\n    return get_torch_device().current_device()\n\n\ndef get_device_name() -> str:\n    \"\"\"Get current device name based on device type.\"\"\"\n    return f\"{get_device_type()}:{get_device_id()}\"\n\n\ndef synchronize() -> None:\n    \"\"\"Execute torch synchronize operation.\"\"\"\n    get_torch_device().synchronize()\n\n\ndef empty_cache() -> None:\n    \"\"\"Execute torch empty cache operation.\"\"\"\n    get_torch_device().empty_cache()\n\n\ndef get_nccl_backend() -> str:\n    \"\"\"Return distributed communication backend type based on device type.\"\"\"\n    if IS_CUDA_AVAILABLE:\n        return \"nccl\"\n    elif IS_NPU_AVAILABLE:\n        return \"hccl\"\n    else:\n        raise RuntimeError(f\"No available distributed communication backend found on device type {get_device_type()}.\")\n\n\ndef 
enable_high_precision_for_bf16():\n    \"\"\"\n    Set high accumulation dtype for matmul and reduction.\n    \"\"\"\n    if IS_CUDA_AVAILABLE:\n        torch.backends.cuda.matmul.allow_tf32 = False\n        torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False\n\n    if IS_NPU_AVAILABLE:\n        torch.npu.matmul.allow_tf32 = False\n        torch.npu.matmul.allow_bf16_reduced_precision_reduction = False\n\n\ndef parse_device_type(device):\n    if isinstance(device, str):\n        if device.startswith(\"cuda\"):\n            return \"cuda\"\n        elif device.startswith(\"npu\"):\n            return \"npu\"\n        else:\n            return \"cpu\"\n    elif isinstance(device, torch.device):\n        return device.type\n\n\ndef parse_nccl_backend(device_type):\n    if device_type == \"cuda\":\n        return \"nccl\"\n    elif device_type == \"npu\":\n        return \"hccl\"\n    else:\n        raise RuntimeError(f\"No available distributed communication backend found on device type {device_type}.\")\n\n\ndef get_available_device_type():\n    return get_device_type()\n"
  },
  {
    "path": "diffsynth/core/gradient/__init__.py",
    "content": "from .gradient_checkpoint import gradient_checkpoint_forward\n"
  },
  {
    "path": "diffsynth/core/gradient/gradient_checkpoint.py",
    "content": "import torch\n\n\ntry:\n    import deepspeed\n    _HAS_DEEPSPEED = True\nexcept ModuleNotFoundError:\n    _HAS_DEEPSPEED = False\n\n\ndef create_custom_forward(module):\n    def custom_forward(*inputs, **kwargs):\n        return module(*inputs, **kwargs)\n    return custom_forward\n\n\ndef create_custom_forward_use_reentrant(module):\n    def custom_forward(*inputs):\n        return module(*inputs)\n    return custom_forward\n\n\ndef judge_args_requires_grad(*args):\n    for arg in args:\n        if isinstance(arg, torch.Tensor) and arg.requires_grad:\n            return True\n    return False\n\n\ndef gradient_checkpoint_forward(\n    model,\n    use_gradient_checkpointing,\n    use_gradient_checkpointing_offload,\n    *args,\n    **kwargs,\n):\n    if use_gradient_checkpointing and _HAS_DEEPSPEED and deepspeed.checkpointing.is_configured():\n        all_args = args + tuple(kwargs.values())\n        if not judge_args_requires_grad(*all_args):\n            # get the first grad_enabled tensor from un_checkpointed forward\n            model_output = model(*args, **kwargs)\n        else:\n            model_output = deepspeed.checkpointing.checkpoint(\n                create_custom_forward_use_reentrant(model),\n                *all_args,\n            )\n        return model_output\n    if use_gradient_checkpointing_offload:\n        with torch.autograd.graph.save_on_cpu():\n            model_output = torch.utils.checkpoint.checkpoint(\n                create_custom_forward(model),\n                *args,\n                **kwargs,\n                use_reentrant=False,\n            )\n    elif use_gradient_checkpointing:\n        model_output = torch.utils.checkpoint.checkpoint(\n            create_custom_forward(model),\n            *args,\n            **kwargs,\n            use_reentrant=False,\n        )\n    else:\n        model_output = model(*args, **kwargs)\n    return model_output\n"
  },
  {
    "path": "diffsynth/core/loader/__init__.py",
    "content": "from .file import load_state_dict, hash_state_dict_keys, hash_model_file\nfrom .model import load_model, load_model_with_disk_offload\nfrom .config import ModelConfig\n"
  },
  {
    "path": "diffsynth/core/loader/config.py",
    "content": "import torch, glob, os\nfrom typing import Optional, Union, Dict\nfrom dataclasses import dataclass\nfrom modelscope import snapshot_download\nfrom huggingface_hub import snapshot_download as hf_snapshot_download\nfrom typing import Optional\n\n\n@dataclass\nclass ModelConfig:\n    path: Union[str, list[str]] = None\n    model_id: str = None\n    origin_file_pattern: Union[str, list[str]] = None\n    download_source: str = None\n    local_model_path: str = None\n    skip_download: bool = None\n    offload_device: Optional[Union[str, torch.device]] = None\n    offload_dtype: Optional[torch.dtype] = None\n    onload_device: Optional[Union[str, torch.device]] = None\n    onload_dtype: Optional[torch.dtype] = None\n    preparing_device: Optional[Union[str, torch.device]] = None\n    preparing_dtype: Optional[torch.dtype] = None\n    computation_device: Optional[Union[str, torch.device]] = None\n    computation_dtype: Optional[torch.dtype] = None\n    clear_parameters: bool = False\n    state_dict: Dict[str, torch.Tensor] = None\n    \n    def check_input(self):\n        if self.path is None and self.model_id is None:\n            raise ValueError(f\"\"\"No valid model files. Please use `ModelConfig(path=\"xxx\")` or `ModelConfig(model_id=\"xxx/yyy\", origin_file_pattern=\"zzz\")`. 
`skip_download=True` only supports the first one.\"\"\")\n    \n    def parse_original_file_pattern(self):\n        if self.origin_file_pattern in [None, \"\", \"./\"]:\n            return \"*\"\n        elif self.origin_file_pattern.endswith(\"/\"):\n            return self.origin_file_pattern + \"*\"\n        else:\n            return self.origin_file_pattern\n        \n    def parse_download_source(self):\n        if self.download_source is None:\n            if os.environ.get('DIFFSYNTH_DOWNLOAD_SOURCE') is not None:\n                return os.environ.get('DIFFSYNTH_DOWNLOAD_SOURCE')\n            else:\n                return \"modelscope\"\n        else:\n            return self.download_source\n        \n    def parse_skip_download(self):\n        if self.skip_download is None:\n            if os.environ.get('DIFFSYNTH_SKIP_DOWNLOAD') is not None:\n                if os.environ.get('DIFFSYNTH_SKIP_DOWNLOAD').lower() == \"true\":\n                    return True\n                elif os.environ.get('DIFFSYNTH_SKIP_DOWNLOAD').lower() == \"false\":\n                    return False\n            else:\n                return False\n        else:\n            return self.skip_download\n\n    def download(self):\n        origin_file_pattern = self.parse_original_file_pattern()\n        downloaded_files = glob.glob(origin_file_pattern, root_dir=os.path.join(self.local_model_path, self.model_id))\n        download_source = self.parse_download_source()\n        if download_source.lower() == \"modelscope\":\n            snapshot_download(\n                self.model_id,\n                local_dir=os.path.join(self.local_model_path, self.model_id),\n                allow_file_pattern=origin_file_pattern,\n                ignore_file_pattern=downloaded_files,\n                local_files_only=False\n            )\n        elif download_source.lower() == \"huggingface\":\n            hf_snapshot_download(\n                self.model_id,\n                
local_dir=os.path.join(self.local_model_path, self.model_id),\n                allow_patterns=origin_file_pattern,\n                ignore_patterns=downloaded_files,\n                local_files_only=False\n            )\n        else:\n            raise ValueError(\"`download_source` should be `modelscope` or `huggingface`.\")\n        \n    def require_downloading(self):\n        if self.path is not None:\n            return False\n        skip_download = self.parse_skip_download()\n        return not skip_download\n    \n    def reset_local_model_path(self):\n        if os.environ.get('DIFFSYNTH_MODEL_BASE_PATH') is not None:\n            self.local_model_path = os.environ.get('DIFFSYNTH_MODEL_BASE_PATH')\n        elif self.local_model_path is None:\n            self.local_model_path = \"./models\"\n\n    def download_if_necessary(self):\n        self.check_input()\n        self.reset_local_model_path()\n        if self.require_downloading():\n            self.download()\n        if self.path is None:\n            if self.origin_file_pattern in [None, \"\", \"./\"]:\n                self.path = os.path.join(self.local_model_path, self.model_id)\n            else:\n                self.path = glob.glob(os.path.join(self.local_model_path, self.model_id, self.origin_file_pattern))\n        if isinstance(self.path, list) and len(self.path) == 1:\n            self.path = self.path[0]\n\n    def vram_config(self):\n        return {\n            \"offload_device\": self.offload_device,\n            \"offload_dtype\": self.offload_dtype,\n            \"onload_device\": self.onload_device,\n            \"onload_dtype\": self.onload_dtype,\n            \"preparing_device\": self.preparing_device,\n            \"preparing_dtype\": self.preparing_dtype,\n            \"computation_device\": self.computation_device,\n            \"computation_dtype\": self.computation_dtype,\n        }\n"
  },
  {
    "path": "diffsynth/core/loader/file.py",
    "content": "from safetensors import safe_open\nimport torch, hashlib\n\n\ndef load_state_dict(file_path, torch_dtype=None, device=\"cpu\", pin_memory=False, verbose=0):\n    if isinstance(file_path, list):\n        state_dict = {}\n        for file_path_ in file_path:\n            state_dict.update(load_state_dict(file_path_, torch_dtype, device, pin_memory=pin_memory, verbose=verbose))\n    else:\n        if verbose >= 1:\n            print(f\"Loading file [started]: {file_path}\")\n        if file_path.endswith(\".safetensors\"):\n            state_dict = load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype, device=device)\n        else:\n            state_dict = load_state_dict_from_bin(file_path, torch_dtype=torch_dtype, device=device)\n        # If load state dict in CPU memory, `pin_memory=True` will make `model.to(\"cuda\")` faster.\n        if pin_memory:\n            for i in state_dict:\n                state_dict[i] = state_dict[i].pin_memory()\n        if verbose >= 1:\n            print(f\"Loading file [done]: {file_path}\")\n    return state_dict\n\n\ndef load_state_dict_from_safetensors(file_path, torch_dtype=None, device=\"cpu\"):\n    state_dict = {}\n    with safe_open(file_path, framework=\"pt\", device=str(device)) as f:\n        for k in f.keys():\n            state_dict[k] = f.get_tensor(k)\n            if torch_dtype is not None:\n                state_dict[k] = state_dict[k].to(torch_dtype)\n    return state_dict\n\n\ndef load_state_dict_from_bin(file_path, torch_dtype=None, device=\"cpu\"):\n    state_dict = torch.load(file_path, map_location=device, weights_only=True)\n    if len(state_dict) == 1:\n        if \"state_dict\" in state_dict:\n            state_dict = state_dict[\"state_dict\"]\n        elif \"module\" in state_dict:\n            state_dict = state_dict[\"module\"]\n        elif \"model_state\" in state_dict:\n            state_dict = state_dict[\"model_state\"]\n    if torch_dtype is not None:\n        for 
i in state_dict:\n            if isinstance(state_dict[i], torch.Tensor):\n                state_dict[i] = state_dict[i].to(torch_dtype)\n    return state_dict\n\n\ndef convert_state_dict_keys_to_single_str(state_dict, with_shape=True):\n    keys = []\n    for key, value in state_dict.items():\n        if isinstance(key, str):\n            if isinstance(value, torch.Tensor):\n                if with_shape:\n                    shape = \"_\".join(map(str, list(value.shape)))\n                    keys.append(key + \":\" + shape)\n                keys.append(key)\n            elif isinstance(value, dict):\n                keys.append(key + \"|\" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape))\n    keys.sort()\n    keys_str = \",\".join(keys)\n    return keys_str\n\n\ndef hash_state_dict_keys(state_dict, with_shape=True):\n    keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)\n    keys_str = keys_str.encode(encoding=\"UTF-8\")\n    return hashlib.md5(keys_str).hexdigest()\n\n\ndef load_keys_dict(file_path):\n    if isinstance(file_path, list):\n        state_dict = {}\n        for file_path_ in file_path:\n            state_dict.update(load_keys_dict(file_path_))\n        return state_dict\n    if file_path.endswith(\".safetensors\"):\n        return load_keys_dict_from_safetensors(file_path)\n    else:\n        return load_keys_dict_from_bin(file_path)\n\n\ndef load_keys_dict_from_safetensors(file_path):\n    keys_dict = {}\n    with safe_open(file_path, framework=\"pt\", device=\"cpu\") as f:\n        for k in f.keys():\n            keys_dict[k] = f.get_slice(k).get_shape()\n    return keys_dict\n\n\ndef convert_state_dict_to_keys_dict(state_dict):\n    keys_dict = {}\n    for k, v in state_dict.items():\n        if isinstance(v, torch.Tensor):\n            keys_dict[k] = list(v.shape)\n        else:\n            keys_dict[k] = convert_state_dict_to_keys_dict(v)\n    return keys_dict\n\n\ndef 
load_keys_dict_from_bin(file_path):\n    state_dict = load_state_dict_from_bin(file_path)\n    keys_dict = convert_state_dict_to_keys_dict(state_dict)\n    return keys_dict\n\n\ndef convert_keys_dict_to_single_str(state_dict, with_shape=True):\n    keys = []\n    for key, value in state_dict.items():\n        if isinstance(key, str):\n            if isinstance(value, dict):\n                keys.append(key + \"|\" + convert_keys_dict_to_single_str(value, with_shape=with_shape))\n            else:\n                if with_shape:\n                    shape = \"_\".join(map(str, list(value)))\n                    keys.append(key + \":\" + shape)\n                keys.append(key)\n    keys.sort()\n    keys_str = \",\".join(keys)\n    return keys_str\n\n\ndef hash_model_file(path, with_shape=True):\n    keys_dict = load_keys_dict(path)\n    keys_str = convert_keys_dict_to_single_str(keys_dict, with_shape=with_shape)\n    keys_str = keys_str.encode(encoding=\"UTF-8\")\n    return hashlib.md5(keys_str).hexdigest()\n"
  },
  {
    "path": "diffsynth/core/loader/model.py",
    "content": "from ..vram.initialization import skip_model_initialization\nfrom ..vram.disk_map import DiskMap\nfrom ..vram.layers import enable_vram_management\nfrom .file import load_state_dict\nimport torch\nfrom contextlib import contextmanager\nfrom transformers.integrations import is_deepspeed_zero3_enabled\nfrom transformers.utils import ContextManagers\n\n\ndef load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device=\"cpu\", state_dict_converter=None, use_disk_map=False, module_map=None, vram_config=None, vram_limit=None, state_dict=None):\n    config = {} if config is None else config\n    with ContextManagers(get_init_context(torch_dtype=torch_dtype, device=device)):\n        model = model_class(**config)\n    # What is `module_map`?\n    # This is a module mapping table for VRAM management.\n    if module_map is not None:\n        devices = [vram_config[\"offload_device\"], vram_config[\"onload_device\"], vram_config[\"preparing_device\"], vram_config[\"computation_device\"]]\n        device = [d for d in devices if d != \"disk\"][0]\n        dtypes = [vram_config[\"offload_dtype\"], vram_config[\"onload_dtype\"], vram_config[\"preparing_dtype\"], vram_config[\"computation_dtype\"]]\n        dtype = [d for d in dtypes if d != \"disk\"][0]\n        if vram_config[\"offload_device\"] != \"disk\":\n            if state_dict is None: state_dict = DiskMap(path, device, torch_dtype=dtype)\n            if state_dict_converter is not None:\n                state_dict = state_dict_converter(state_dict)\n            else:\n                state_dict = {i: state_dict[i] for i in state_dict}\n            model.load_state_dict(state_dict, assign=True)\n            model = enable_vram_management(model, module_map, vram_config=vram_config, disk_map=None, vram_limit=vram_limit)\n        else:\n            disk_map = DiskMap(path, device, state_dict_converter=state_dict_converter)\n            model = enable_vram_management(model, module_map, 
vram_config=vram_config, disk_map=disk_map, vram_limit=vram_limit)\n    else:\n        # Why do we use `DiskMap`?\n        # Sometimes a model file contains multiple models,\n        # and DiskMap can load only the parameters of a single model,\n        # avoiding the need to load all parameters in the file.\n        if state_dict is not None:\n            pass\n        elif use_disk_map:\n            state_dict = DiskMap(path, device, torch_dtype=torch_dtype)\n        else:\n            state_dict = load_state_dict(path, torch_dtype, device)\n        # Why do we use `state_dict_converter`?\n        # Some models are saved in complex formats,\n        # and we need to convert the state dict into the appropriate format.\n        if state_dict_converter is not None:\n            state_dict = state_dict_converter(state_dict)\n        else:\n            state_dict = {i: state_dict[i] for i in state_dict}\n        # Why does DeepSpeed ZeRO Stage 3 need to be handled separately?\n        # Because at this stage, model parameters are partitioned across multiple GPUs.\n        # Loading them directly could lead to excessive GPU memory consumption.\n        if is_deepspeed_zero3_enabled():\n            from transformers.integrations.deepspeed import _load_state_dict_into_zero3_model\n            _load_state_dict_into_zero3_model(model, state_dict)\n        else:\n            model.load_state_dict(state_dict, assign=True)\n        # Why do we call `to()`?\n        # Because some models override the behavior of `to()`,\n        # especially those from libraries like Transformers.\n        model = model.to(dtype=torch_dtype, device=device)\n    if hasattr(model, \"eval\"):\n        model = model.eval()\n    return model\n\n\ndef load_model_with_disk_offload(model_class, path, config=None, torch_dtype=torch.bfloat16, device=\"cpu\", state_dict_converter=None, module_map=None):\n    if isinstance(path, str):\n        path = [path]\n    config = {} if config is None else config\n 
   with skip_model_initialization():\n        model = model_class(**config)\n    if hasattr(model, \"eval\"):\n        model = model.eval()\n    disk_map = DiskMap(path, device, state_dict_converter=state_dict_converter)\n    vram_config = {\n        \"offload_dtype\": \"disk\",\n        \"offload_device\": \"disk\",\n        \"onload_dtype\": \"disk\",\n        \"onload_device\": \"disk\",\n        \"preparing_dtype\": torch.float8_e4m3fn,\n        \"preparing_device\": device,\n        \"computation_dtype\": torch_dtype,\n        \"computation_device\": device,\n    }\n    enable_vram_management(model, module_map, vram_config=vram_config, disk_map=disk_map, vram_limit=80)\n    return model\n\n\ndef get_init_context(torch_dtype, device):\n    if is_deepspeed_zero3_enabled():\n        from transformers.modeling_utils import set_zero3_state\n        import deepspeed\n        # Why do we use \"deepspeed.zero.Init\"?\n        # Weight segmentation of the model can be performed on the CPU side\n        # and loading the segmented weights onto the computing card\n        init_contexts = [deepspeed.zero.Init(remote_device=device, dtype=torch_dtype), set_zero3_state()]\n    else:\n        # Why do we use `skip_model_initialization`?\n        # It skips the random initialization of model parameters,\n        # thereby speeding up model loading and avoiding excessive memory usage.\n        init_contexts = [skip_model_initialization()]\n\n    return init_contexts\n"
  },
  {
    "path": "diffsynth/core/npu_patch/npu_fused_operator.py",
    "content": "import torch\nfrom ..device.npu_compatible_device import get_device_type\ntry:\n    import torch_npu\nexcept:\n    pass\n\n\ndef rms_norm_forward_npu(self, hidden_states):\n    \"npu rms fused operator for RMSNorm.forward from diffsynth\\models\\general_modules.py\"\n    if hidden_states.dtype != self.weight.dtype:\n        hidden_states = hidden_states.to(self.weight.dtype)\n    return torch_npu.npu_rms_norm(hidden_states, self.weight, self.eps)[0]\n\n\ndef rms_norm_forward_transformers_npu(self, hidden_states):\n    \"npu rms fused operator for transformers\"\n    if hidden_states.dtype != self.weight.dtype:\n        hidden_states = hidden_states.to(self.weight.dtype)\n    return torch_npu.npu_rms_norm(hidden_states, self.weight, self.variance_epsilon)[0]\n\n\ndef rotary_emb_Zimage_npu(self, x_in: torch.Tensor, freqs_cis: torch.Tensor):\n    \"npu rope fused operator for Zimage\"\n    with torch.amp.autocast(get_device_type(), enabled=False):\n        freqs_cis = freqs_cis.unsqueeze(2)\n        cos, sin = torch.chunk(torch.view_as_real(freqs_cis), 2, dim=-1)\n        cos = cos.expand(-1, -1, -1, -1, 2).flatten(-2)\n        sin = sin.expand(-1, -1, -1, -1, 2).flatten(-2)\n        return torch_npu.npu_rotary_mul(x_in, cos, sin, rotary_mode=\"interleave\").to(x_in)"
  },
  {
    "path": "diffsynth/core/vram/__init__.py",
    "content": "from .initialization import skip_model_initialization\nfrom .layers import *\n"
  },
  {
    "path": "diffsynth/core/vram/disk_map.py",
    "content": "from safetensors import safe_open\nimport torch, os\n\n\nclass SafetensorsCompatibleTensor:\n    def __init__(self, tensor):\n        self.tensor = tensor\n    \n    def get_shape(self):\n        return list(self.tensor.shape)\n\n\nclass SafetensorsCompatibleBinaryLoader:\n    def __init__(self, path, device):\n        print(\"Detected non-safetensors files, which may cause slower loading. It's recommended to convert it to a safetensors file.\")\n        self.state_dict = torch.load(path, weights_only=True, map_location=device)\n        \n    def keys(self):\n        return self.state_dict.keys()\n    \n    def get_tensor(self, name):\n        return self.state_dict[name]\n    \n    def get_slice(self, name):\n        return SafetensorsCompatibleTensor(self.state_dict[name])\n\n\nclass DiskMap:\n\n    def __init__(self, path, device, torch_dtype=None, state_dict_converter=None, buffer_size=10**9):\n        self.path = path if isinstance(path, list) else [path]\n        self.device = device\n        self.torch_dtype = torch_dtype\n        if os.environ.get('DIFFSYNTH_DISK_MAP_BUFFER_SIZE') is not None:\n            self.buffer_size = int(os.environ.get('DIFFSYNTH_DISK_MAP_BUFFER_SIZE'))\n        else:\n            self.buffer_size = buffer_size\n        self.files = []\n        self.flush_files()\n        self.name_map = {}\n        for file_id, file in enumerate(self.files):\n            for name in file.keys():\n                self.name_map[name] = file_id\n        self.rename_dict = self.fetch_rename_dict(state_dict_converter)\n        \n    def flush_files(self):\n        if len(self.files) == 0:\n            for path in self.path:\n                if path.endswith(\".safetensors\"):\n                    self.files.append(safe_open(path, framework=\"pt\", device=str(self.device)))\n                else:\n                    self.files.append(SafetensorsCompatibleBinaryLoader(path, device=self.device))\n        else:\n            for i, path in 
enumerate(self.path):\n                if path.endswith(\".safetensors\"):\n                    self.files[i] = safe_open(path, framework=\"pt\", device=str(self.device))\n        self.num_params = 0\n\n    def __getitem__(self, name):\n        if self.rename_dict is not None: name = self.rename_dict[name]\n        file_id = self.name_map[name]\n        param = self.files[file_id].get_tensor(name)\n        if self.torch_dtype is not None and isinstance(param, torch.Tensor):\n            param = param.to(self.torch_dtype)\n        if isinstance(param, torch.Tensor) and param.device == \"cpu\":\n            param = param.clone()\n        if isinstance(param, torch.Tensor):\n            self.num_params += param.numel()\n        if self.num_params > self.buffer_size:\n            self.flush_files()\n        return param\n\n    def fetch_rename_dict(self, state_dict_converter):\n        if state_dict_converter is None:\n            return None\n        state_dict = {}\n        for file in self.files:\n            for name in file.keys():\n                state_dict[name] = name\n        state_dict = state_dict_converter(state_dict)\n        return state_dict\n    \n    def __iter__(self):\n        if self.rename_dict is not None:\n            return self.rename_dict.__iter__()\n        else:\n            return self.name_map.__iter__()\n    \n    def __contains__(self, x):\n        if self.rename_dict is not None:\n            return x in self.rename_dict\n        else:\n            return x in self.name_map\n"
  },
  {
    "path": "diffsynth/core/vram/initialization.py",
    "content": "import torch\nfrom contextlib import contextmanager\n\n\n@contextmanager\ndef skip_model_initialization(device=torch.device(\"meta\")):\n\n    def register_empty_parameter(module, name, param):\n        old_register_parameter(module, name, param)\n        if param is not None:\n            param_cls = type(module._parameters[name])\n            kwargs = module._parameters[name].__dict__\n            kwargs[\"requires_grad\"] = param.requires_grad\n            module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)\n\n    old_register_parameter = torch.nn.Module.register_parameter\n    torch.nn.Module.register_parameter = register_empty_parameter\n    try:\n        yield\n    finally:\n        torch.nn.Module.register_parameter = old_register_parameter\n"
  },
  {
    "path": "diffsynth/core/vram/layers.py",
    "content": "import torch, copy\nfrom typing import Union\nfrom .initialization import skip_model_initialization\nfrom .disk_map import DiskMap\nfrom ..device import parse_device_type, get_device_name, IS_NPU_AVAILABLE\n\n\nclass AutoTorchModule(torch.nn.Module):\n\n    def __init__(\n        self,\n        offload_dtype: torch.dtype = None,\n        offload_device: Union[str, torch.device] = None,\n        onload_dtype: torch.dtype = None,\n        onload_device: Union[str, torch.device] = None,\n        preparing_dtype: torch.dtype = None,\n        preparing_device: Union[str, torch.device] = None,\n        computation_dtype: torch.dtype = None,\n        computation_device: Union[str, torch.device] = None,\n        vram_limit: float = None,\n    ):\n        super().__init__()\n        self.set_dtype_and_device(\n            offload_dtype,\n            offload_device,\n            onload_dtype,\n            onload_device,\n            preparing_dtype,\n            preparing_device,\n            computation_dtype,\n            computation_device,\n            vram_limit,\n        )\n        self.state = 0\n        self.name = \"\"\n        self.computation_device_type = parse_device_type(self.computation_device)\n\n    def set_dtype_and_device(\n        self,\n        offload_dtype: torch.dtype = None,\n        offload_device: Union[str, torch.device] = None,\n        onload_dtype: torch.dtype = None,\n        onload_device: Union[str, torch.device] = None,\n        preparing_dtype: torch.dtype = None,\n        preparing_device: Union[str, torch.device] = None,\n        computation_dtype: torch.dtype = None,\n        computation_device: Union[str, torch.device] = None,\n        vram_limit: float = None,\n    ):\n        self.offload_dtype = offload_dtype or computation_dtype\n        self.offload_device = offload_device or computation_dtype\n        self.onload_dtype = onload_dtype or computation_dtype\n        self.onload_device = onload_device or 
computation_dtype\n        self.preparing_dtype = preparing_dtype or computation_dtype\n        self.preparing_device = preparing_device or computation_dtype\n        self.computation_dtype = computation_dtype\n        self.computation_device = computation_device\n        self.vram_limit = vram_limit\n\n    def cast_to(self, weight, dtype, device):\n        r = torch.empty_like(weight, dtype=dtype, device=device)\n        r.copy_(weight)\n        return r\n\n    def check_free_vram(self):\n        device = self.computation_device if not IS_NPU_AVAILABLE else get_device_name()\n        gpu_mem_state = getattr(torch, self.computation_device_type).mem_get_info(device)\n        used_memory = (gpu_mem_state[1] - gpu_mem_state[0]) / (1024**3)\n        return used_memory < self.vram_limit\n\n    def offload(self):\n        if self.state != 0:\n            self.to(dtype=self.offload_dtype, device=self.offload_device)\n            self.state = 0\n\n    def onload(self):\n        if self.state != 1:\n            self.to(dtype=self.onload_dtype, device=self.onload_device)\n            self.state = 1\n            \n    def param_name(self, name):\n        if self.name == \"\":\n            return name\n        else:\n            return self.name + \".\" + name\n\n\nclass AutoWrappedModule(AutoTorchModule):\n\n    def __init__(\n        self,\n        module: torch.nn.Module,\n        offload_dtype: torch.dtype = None,\n        offload_device: Union[str, torch.device] = None,\n        onload_dtype: torch.dtype = None,\n        onload_device: Union[str, torch.device] = None,\n        preparing_dtype: torch.dtype = None,\n        preparing_device: Union[str, torch.device] = None,\n        computation_dtype: torch.dtype = None,\n        computation_device: Union[str, torch.device] = None,\n        vram_limit: float = None,\n        name: str = \"\",\n        disk_map: DiskMap = None,\n        **kwargs\n    ):\n        super().__init__(\n            offload_dtype,\n            
offload_device,\n            onload_dtype,\n            onload_device,\n            preparing_dtype,\n            preparing_device,\n            computation_dtype,\n            computation_device,\n            vram_limit,\n        )\n        self.module = module\n        if offload_dtype == \"disk\":\n            self.name = name\n            self.disk_map = disk_map\n            self.required_params = [name for name, _ in self.module.named_parameters()]\n            self.disk_offload = True\n        else:\n            self.disk_offload = False\n            \n    def load_from_disk(self, torch_dtype, device, copy_module=False):\n        if copy_module:\n            module = copy.deepcopy(self.module)\n        else:\n            module = self.module\n        state_dict = {}\n        for name in self.required_params:\n            param = self.disk_map[self.param_name(name)]\n            param = param.to(dtype=torch_dtype, device=device)\n            state_dict[name] = param\n        module.load_state_dict(state_dict, assign=True)\n        module.to(dtype=torch_dtype, device=device)\n        return module\n    \n    def offload_to_disk(self, model: torch.nn.Module):\n        for buf in model.buffers():\n            # If there are some parameters are registed in buffers (not in state dict),\n            # We cannot offload the model.\n            for children in model.children():\n                self.offload_to_disk(children)\n            break\n        else:\n            model.to(\"meta\")\n\n    def offload(self):\n        # offload / onload / preparing -> offload\n        if self.state != 0:\n            if self.disk_offload:\n                self.offload_to_disk(self.module)\n            else:\n                self.to(dtype=self.offload_dtype, device=self.offload_device)\n            self.state = 0\n\n    def onload(self):\n        # offload / onload / preparing -> onload\n        if self.state < 1:\n            if self.disk_offload and self.onload_device != 
\"disk\" and self.offload_device == \"disk\":\n                self.load_from_disk(self.onload_dtype, self.onload_device)\n            elif self.onload_device != \"disk\":\n                self.to(dtype=self.onload_dtype, device=self.onload_device)\n            self.state = 1\n            \n    def preparing(self):\n        # onload / preparing -> preparing\n        if self.state != 2:\n            if self.disk_offload and self.preparing_device != \"disk\" and self.onload_device == \"disk\":\n                self.load_from_disk(self.preparing_dtype, self.preparing_device)\n            elif self.preparing_device != \"disk\":\n                self.to(dtype=self.preparing_dtype, device=self.preparing_device)\n            self.state = 2\n\n    def cast_to(self, module, dtype, device):\n        return copy.deepcopy(module).to(dtype=dtype, device=device)\n            \n    def computation(self):\n        # onload / preparing -> computation (temporary)\n        if self.state == 2:\n            torch_dtype, device = self.preparing_dtype, self.preparing_device\n        else:\n            torch_dtype, device = self.onload_dtype, self.onload_device\n        if torch_dtype == self.computation_dtype and device == self.computation_device:\n            module = self.module\n        elif self.disk_offload and device == \"disk\":\n            module = self.load_from_disk(self.computation_dtype, self.computation_device, copy_module=True)\n        else:\n            module = self.cast_to(self.module, dtype=self.computation_dtype, device=self.computation_device)\n        return module\n\n    def forward(self, *args, **kwargs):\n        if self.state == 1 and (self.vram_limit is None or self.check_free_vram()):\n            self.preparing()\n        module = self.computation()\n        return module(*args, **kwargs)\n    \n    def __getattr__(self, name):\n        if name in self.__dict__ or name == \"module\":\n            return super().__getattr__(name)\n        else:\n            
return getattr(self.module, name)\n\n\nclass AutoWrappedNonRecurseModule(AutoWrappedModule):\n\n    def __init__(\n        self,\n        module: torch.nn.Module,\n        offload_dtype: torch.dtype = None,\n        offload_device: Union[str, torch.device] = None,\n        onload_dtype: torch.dtype = None,\n        onload_device: Union[str, torch.device] = None,\n        preparing_dtype: torch.dtype = None,\n        preparing_device: Union[str, torch.device] = None,\n        computation_dtype: torch.dtype = None,\n        computation_device: Union[str, torch.device] = None,\n        vram_limit: float = None,\n        name: str = \"\",\n        disk_map: DiskMap = None,\n        **kwargs\n    ):\n        super().__init__(\n            module,\n            offload_dtype,\n            offload_device,\n            onload_dtype,\n            onload_device,\n            preparing_dtype,\n            preparing_device,\n            computation_dtype,\n            computation_device,\n            vram_limit,\n            name,\n            disk_map,\n            **kwargs\n        )\n        if self.disk_offload:\n            self.required_params = [name for name, _ in self.module.named_parameters(recurse=False)]\n            \n    def load_from_disk(self, torch_dtype, device, copy_module=False):\n        if copy_module:\n            module = copy.deepcopy(self.module)\n        else:\n            module = self.module\n        state_dict = {}\n        for name in self.required_params:\n            param = self.disk_map[self.param_name(name)]\n            param = param.to(dtype=torch_dtype, device=device)\n            state_dict[name] = param\n        module.load_state_dict(state_dict, assign=True, strict=False)\n        return module\n    \n    def offload_to_disk(self, model: torch.nn.Module):\n        for name in self.required_params:\n            getattr(self, name).to(\"meta\")\n    \n    def cast_to(self, module, dtype, device):\n        # Parameter casting is 
implemented in the model architecture.\n        return module\n    \n    def __getattr__(self, name):\n        if name in self.__dict__ or name == \"module\":\n            return super().__getattr__(name)\n        else:\n            return getattr(self.module, name)\n\n\nclass AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):\n    def __init__(\n        self,\n        module: torch.nn.Linear,\n        offload_dtype: torch.dtype = None,\n        offload_device: Union[str, torch.device] = None,\n        onload_dtype: torch.dtype = None,\n        onload_device: Union[str, torch.device] = None,\n        preparing_dtype: torch.dtype = None,\n        preparing_device: Union[str, torch.device] = None,\n        computation_dtype: torch.dtype = None,\n        computation_device: Union[str, torch.device] = None,\n        vram_limit: float = None,\n        name: str = \"\",\n        disk_map: DiskMap = None,\n        **kwargs\n    ):\n        with skip_model_initialization():\n            super().__init__(\n                in_features=module.in_features,\n                out_features=module.out_features,\n                bias=module.bias is not None,\n            )\n        self.set_dtype_and_device(\n            offload_dtype,\n            offload_device,\n            onload_dtype,\n            onload_device,\n            preparing_dtype,\n            preparing_device,\n            computation_dtype,\n            computation_device,\n            vram_limit,\n        )\n        self.weight = module.weight\n        self.bias = module.bias\n        self.state = 0\n        self.name = name\n        self.lora_A_weights = []\n        self.lora_B_weights = []\n        self.lora_merger = None\n        self.enable_fp8 = computation_dtype in [torch.float8_e4m3fn, torch.float8_e4m3fnuz]\n        self.computation_device_type = parse_device_type(self.computation_device)\n        \n        if offload_dtype == \"disk\":\n            self.disk_map = disk_map\n            
self.disk_offload = True\n        else:\n            self.disk_offload = False\n    \n    def fp8_linear(\n        self,\n        input: torch.Tensor,\n        weight: torch.Tensor,\n        bias: torch.Tensor = None,\n    ) -> torch.Tensor:\n        device = input.device\n        origin_dtype = input.dtype\n        origin_shape = input.shape\n        input = input.reshape(-1, origin_shape[-1])\n\n        x_max = torch.max(torch.abs(input), dim=-1, keepdim=True).values\n        fp8_max = 448.0\n        # For float8_e4m3fnuz, the maximum representable value is half of that of e4m3fn.\n        # To avoid overflow and ensure numerical compatibility during FP8 computation,\n        # we scale down the input by 2.0 in advance.\n        # This scaling will be compensated later during the final result scaling.\n        if self.computation_dtype == torch.float8_e4m3fnuz:\n            fp8_max = fp8_max / 2.0\n        scale_a = torch.clamp(x_max / fp8_max, min=1.0).float().to(device=device)\n        scale_b = torch.ones((weight.shape[0], 1)).to(device=device)\n        input = input / (scale_a + 1e-8)\n        input = input.to(self.computation_dtype)\n        weight = weight.to(self.computation_dtype)\n        bias = bias.to(torch.bfloat16)\n\n        result = torch._scaled_mm(\n            input,\n            weight.T,\n            scale_a=scale_a,\n            scale_b=scale_b.T,\n            bias=bias,\n            out_dtype=origin_dtype,\n        )\n        new_shape = origin_shape[:-1] + result.shape[-1:]\n        result = result.reshape(new_shape)\n        return result\n            \n    def load_from_disk(self, torch_dtype, device, assign=True):\n        weight = self.disk_map[self.name + \".weight\"].to(dtype=torch_dtype, device=device)\n        bias = None if self.bias is None else self.disk_map[self.name + \".bias\"].to(dtype=torch_dtype, device=device)\n        if assign:\n            state_dict = {\"weight\": weight}\n            if bias is not None: 
state_dict[\"bias\"] = bias\n            self.load_state_dict(state_dict, assign=True)\n        return weight, bias\n    \n    def offload(self):\n        # offload / onload / preparing -> offload\n        if self.state != 0:\n            if self.disk_offload:\n                self.to(\"meta\")\n            else:\n                self.to(dtype=self.offload_dtype, device=self.offload_device)\n            self.state = 0\n\n    def onload(self):\n        # offload / onload / preparing -> onload\n        if self.state < 1:\n            if self.disk_offload and self.onload_device != \"disk\" and self.offload_device == \"disk\":\n                self.load_from_disk(self.onload_dtype, self.onload_device)\n            elif self.onload_device != \"disk\":\n                self.to(dtype=self.onload_dtype, device=self.onload_device)\n            self.state = 1\n            \n    def preparing(self):\n        # onload / preparing -> preparing\n        if self.state != 2:\n            if self.disk_offload and self.preparing_device != \"disk\" and self.onload_device == \"disk\":\n                self.load_from_disk(self.preparing_dtype, self.preparing_device)\n            elif self.preparing_device != \"disk\":\n                self.to(dtype=self.preparing_dtype, device=self.preparing_device)\n            self.state = 2\n            \n    def computation(self):\n        # onload / preparing -> computation (temporary)\n        if self.state == 2:\n            torch_dtype, device = self.preparing_dtype, self.preparing_device\n        else:\n            torch_dtype, device = self.onload_dtype, self.onload_device\n        if torch_dtype == self.computation_dtype and device == self.computation_device:\n            weight, bias = self.weight, self.bias\n        elif self.disk_offload and device == \"disk\":\n            weight, bias = self.load_from_disk(self.computation_dtype, self.computation_device, assign=False)\n        else:\n            weight = self.cast_to(self.weight, 
self.computation_dtype, self.computation_device)\n            bias = None if self.bias is None else self.cast_to(self.bias, self.computation_dtype, self.computation_device)\n        return weight, bias\n\n    def linear_forward(self, x, weight, bias):\n        if self.enable_fp8:\n            out = self.fp8_linear(x, weight, bias)\n        else:\n            out = torch.nn.functional.linear(x, weight, bias)\n        return out\n\n    def lora_forward(self, x, out):\n        if self.lora_merger is None:\n            for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights):\n                out = out + x @ lora_A.T.to(device=x.device, dtype=x.dtype) @ lora_B.T.to(device=x.device, dtype=x.dtype)\n        else:\n            lora_output = []\n            for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights):\n                lora_output.append(x @ lora_A.T @ lora_B.T)\n            lora_output = torch.stack(lora_output)\n            out = self.lora_merger(out, lora_output)\n        return out\n    \n    def forward(self, x, *args, **kwargs):\n        if self.state == 1 and (self.vram_limit is None or self.check_free_vram()):\n            self.preparing()\n        weight, bias = self.computation()\n        out = self.linear_forward(x, weight, bias)\n        if len(self.lora_A_weights) > 0:\n            out = self.lora_forward(x, out)\n        return out\n\n\ndef enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, vram_config: dict, vram_limit=None, name_prefix=\"\", disk_map=None, **kwargs):\n    if isinstance(model, AutoWrappedNonRecurseModule):\n        model = model.module\n    for name, module in model.named_children():\n        layer_name = name if name_prefix == \"\" else name_prefix + \".\" + name\n        for source_module, target_module in module_map.items():\n            if isinstance(module, source_module):\n                module_ = target_module(module, **vram_config, vram_limit=vram_limit, name=layer_name, 
disk_map=disk_map, **kwargs)\n                if isinstance(module_, AutoWrappedNonRecurseModule):\n                    enable_vram_management_recursively(module_, module_map, vram_config, vram_limit=vram_limit, name_prefix=layer_name, disk_map=disk_map, **kwargs)\n                setattr(model, name, module_)\n                break\n        else:\n            enable_vram_management_recursively(module, module_map, vram_config, vram_limit=vram_limit, name_prefix=layer_name, disk_map=disk_map, **kwargs)\n\n\ndef fill_vram_config(model, vram_config):\n    vram_config_ = vram_config.copy()\n    vram_config_[\"onload_dtype\"] = vram_config[\"computation_dtype\"]\n    vram_config_[\"onload_device\"] = vram_config[\"computation_device\"]\n    vram_config_[\"preparing_dtype\"] = vram_config[\"computation_dtype\"]\n    vram_config_[\"preparing_device\"] = vram_config[\"computation_device\"]\n    for k in vram_config:\n        if vram_config[k] != vram_config_[k]:\n            print(f\"No fine-grained VRAM configuration is provided for {model.__class__.__name__}. [`onload`, `preparing`, `computation`] will be the same state. 
`vram_config` is set to {vram_config_}\")\n            break\n    return vram_config_\n\n\ndef enable_vram_management(model: torch.nn.Module, module_map: dict, vram_config: dict, vram_limit=None, disk_map=None, **kwargs):\n    for source_module, target_module in module_map.items():\n        # If no fine-grained VRAM configuration is provided, the entire model will be managed uniformly.\n        if isinstance(model, source_module):\n            vram_config = fill_vram_config(model, vram_config)\n            model = target_module(model, **vram_config, vram_limit=vram_limit, disk_map=disk_map, **kwargs)\n            break\n    else:\n        enable_vram_management_recursively(model, module_map, vram_config, vram_limit=vram_limit, disk_map=disk_map, **kwargs)\n    # `vram_management_enabled` is a flag that allows the pipeline to determine whether VRAM management is enabled.\n    model.vram_management_enabled = True\n    return model\n"
  },
  {
    "path": "diffsynth/diffusion/__init__.py",
    "content": "from .flow_match import FlowMatchScheduler\nfrom .training_module import DiffusionTrainingModule\nfrom .logger import ModelLogger\nfrom .runner import launch_training_task, launch_data_process_task\nfrom .parsers import *\nfrom .loss import *\n"
  },
  {
    "path": "diffsynth/diffusion/base_pipeline.py",
    "content": "from PIL import Image\nimport torch\nimport numpy as np\nfrom einops import repeat, reduce\nfrom typing import Union\nfrom ..core import AutoTorchModule, AutoWrappedLinear, load_state_dict, ModelConfig, parse_device_type\nfrom ..core.device.npu_compatible_device import get_device_type\nfrom ..utils.lora import GeneralLoRALoader\nfrom ..models.model_loader import ModelPool\nfrom ..utils.controlnet import ControlNetInput\nfrom ..core.device import get_device_name, IS_NPU_AVAILABLE\n\n\nclass PipelineUnit:\n    def __init__(\n        self,\n        seperate_cfg: bool = False,\n        take_over: bool = False,\n        input_params: tuple[str] = None,\n        output_params: tuple[str] = None,\n        input_params_posi: dict[str, str] = None,\n        input_params_nega: dict[str, str] = None,\n        onload_model_names: tuple[str] = None\n    ):\n        self.seperate_cfg = seperate_cfg\n        self.take_over = take_over\n        self.input_params = input_params\n        self.output_params = output_params\n        self.input_params_posi = input_params_posi\n        self.input_params_nega = input_params_nega\n        self.onload_model_names = onload_model_names\n\n    def fetch_input_params(self):\n        params = []\n        if self.input_params is not None:\n            for param in self.input_params:\n                params.append(param)\n        if self.input_params_posi is not None:\n            for _, param in self.input_params_posi.items():\n                params.append(param)\n        if self.input_params_nega is not None:\n            for _, param in self.input_params_nega.items():\n                params.append(param)\n        params = sorted(list(set(params)))\n        return params\n    \n    def fetch_output_params(self):\n        params = []\n        if self.output_params is not None:\n            for param in self.output_params:\n                params.append(param)\n        return params\n\n    def process(self, pipe, **kwargs) -> 
dict:\n        return {}\n    \n    def post_process(self, pipe, **kwargs) -> dict:\n        return {}\n\n\nclass BasePipeline(torch.nn.Module):\n\n    def __init__(\n        self,\n        device=get_device_type(), torch_dtype=torch.float16,\n        height_division_factor=64, width_division_factor=64,\n        time_division_factor=None, time_division_remainder=None,\n    ):\n        super().__init__()\n        # The device and torch_dtype is used for the storage of intermediate variables, not models.\n        self.device = device\n        self.torch_dtype = torch_dtype\n        self.device_type = parse_device_type(device)\n        # The following parameters are used for shape check.\n        self.height_division_factor = height_division_factor\n        self.width_division_factor = width_division_factor\n        self.time_division_factor = time_division_factor\n        self.time_division_remainder = time_division_remainder\n        # VRAM management\n        self.vram_management_enabled = False\n        # Pipeline Unit Runner\n        self.unit_runner = PipelineUnitRunner()\n        # LoRA Loader\n        self.lora_loader = GeneralLoRALoader\n        \n        \n    def to(self, *args, **kwargs):\n        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)\n        if device is not None:\n            self.device = device\n        if dtype is not None:\n            self.torch_dtype = dtype\n        super().to(*args, **kwargs)\n        return self\n\n\n    def check_resize_height_width(self, height, width, num_frames=None, verbose=1):\n        # Shape check\n        if height % self.height_division_factor != 0:\n            height = (height + self.height_division_factor - 1) // self.height_division_factor * self.height_division_factor\n            if verbose > 0:\n                print(f\"height % {self.height_division_factor} != 0. 
We round it up to {height}.\")\n        if width % self.width_division_factor != 0:\n            width = (width + self.width_division_factor - 1) // self.width_division_factor * self.width_division_factor\n            if verbose > 0:\n                print(f\"width % {self.width_division_factor} != 0. We round it up to {width}.\")\n        if num_frames is None:\n            return height, width\n        else:\n            if num_frames % self.time_division_factor != self.time_division_remainder:\n                num_frames = (num_frames + self.time_division_factor - 1) // self.time_division_factor * self.time_division_factor + self.time_division_remainder\n                if verbose > 0:\n                    print(f\"num_frames % {self.time_division_factor} != {self.time_division_remainder}. We round it up to {num_frames}.\")\n            return height, width, num_frames\n\n\n    def preprocess_image(self, image, torch_dtype=None, device=None, pattern=\"B C H W\", min_value=-1, max_value=1):\n        # Transform a PIL.Image to torch.Tensor\n        image = torch.Tensor(np.array(image, dtype=np.float32))\n        image = image.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device)\n        image = image * ((max_value - min_value) / 255) + min_value\n        image = repeat(image, f\"H W C -> {pattern}\", **({\"B\": 1} if \"B\" in pattern else {}))\n        return image\n\n\n    def preprocess_video(self, video, torch_dtype=None, device=None, pattern=\"B C T H W\", min_value=-1, max_value=1):\n        # Transform a list of PIL.Image to torch.Tensor\n        video = [self.preprocess_image(image, torch_dtype=torch_dtype, device=device, min_value=min_value, max_value=max_value) for image in video]\n        video = torch.stack(video, dim=pattern.index(\"T\") // 2)\n        return video\n\n\n    def vae_output_to_image(self, vae_output, pattern=\"B C H W\", min_value=-1, max_value=1):\n        # Transform a torch.Tensor to PIL.Image\n        if pattern != 
\"H W C\":\n            vae_output = reduce(vae_output, f\"{pattern} -> H W C\", reduction=\"mean\")\n        image = ((vae_output - min_value) * (255 / (max_value - min_value))).clip(0, 255)\n        image = image.to(device=\"cpu\", dtype=torch.uint8)\n        image = Image.fromarray(image.numpy())\n        return image\n\n\n    def vae_output_to_video(self, vae_output, pattern=\"B C T H W\", min_value=-1, max_value=1):\n        # Transform a torch.Tensor to list of PIL.Image\n        if pattern != \"T H W C\":\n            vae_output = reduce(vae_output, f\"{pattern} -> T H W C\", reduction=\"mean\")\n        video = [self.vae_output_to_image(image, pattern=\"H W C\", min_value=min_value, max_value=max_value) for image in vae_output]\n        return video\n\n    def output_audio_format_check(self, audio_output):\n        # output standard foramt: [C, T], output dtype: float()\n        # remove batch dim\n        if audio_output.ndim == 3:\n            audio_output = audio_output.squeeze(0)\n        return audio_output.float()\n\n    def load_models_to_device(self, model_names):\n        if self.vram_management_enabled:\n            # offload models\n            for name, model in self.named_children():\n                if name not in model_names:\n                    if hasattr(model, \"vram_management_enabled\") and model.vram_management_enabled:\n                        if hasattr(model, \"offload\"):\n                            model.offload()\n                        else:\n                            for module in model.modules():\n                                if hasattr(module, \"offload\"):\n                                    module.offload()\n            getattr(torch, self.device_type).empty_cache()\n            # onload models\n            for name, model in self.named_children():\n                if name in model_names:\n                    if hasattr(model, \"vram_management_enabled\") and model.vram_management_enabled:\n                        
if hasattr(model, \"onload\"):\n                            model.onload()\n                        else:\n                            for module in model.modules():\n                                if hasattr(module, \"onload\"):\n                                    module.onload()\n\n\n    def generate_noise(self, shape, seed=None, rand_device=\"cpu\", rand_torch_dtype=torch.float32, device=None, torch_dtype=None):\n        # Initialize Gaussian noise\n        generator = None if seed is None else torch.Generator(rand_device).manual_seed(seed)\n        noise = torch.randn(shape, generator=generator, device=rand_device, dtype=rand_torch_dtype)\n        noise = noise.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device)\n        return noise\n\n        \n    def get_vram(self):\n        device = self.device if not IS_NPU_AVAILABLE else get_device_name()\n        return getattr(torch, self.device_type).mem_get_info(device)[1] / (1024 ** 3)\n    \n    def get_module(self, model, name):\n        if \".\" in name:\n            name, suffix = name[:name.index(\".\")], name[name.index(\".\") + 1:]\n            if name.isdigit():\n                return self.get_module(model[int(name)], suffix)\n            else:\n                return self.get_module(getattr(model, name), suffix)\n        else:\n            return getattr(model, name)\n    \n    def freeze_except(self, model_names):\n        self.eval()\n        self.requires_grad_(False)\n        for name in model_names:\n            module = self.get_module(self, name)\n            if module is None:\n                print(f\"No {name} models in the pipeline. We cannot enable training on the model. 
If this occurs during the data processing stage, it is normal.\")\n                continue\n            module.train()\n            module.requires_grad_(True)\n                \n    \n    def blend_with_mask(self, base, addition, mask):\n        return base * (1 - mask) + addition * mask\n    \n    \n    def step(self, scheduler, latents, progress_id, noise_pred, input_latents=None, inpaint_mask=None, **kwargs):\n        timestep = scheduler.timesteps[progress_id]\n        if inpaint_mask is not None:\n            noise_pred_expected = scheduler.return_to_timestep(scheduler.timesteps[progress_id], latents, input_latents)\n            noise_pred = self.blend_with_mask(noise_pred_expected, noise_pred, inpaint_mask)\n        latents_next = scheduler.step(noise_pred, timestep, latents)\n        return latents_next\n    \n    \n    def split_pipeline_units(self, model_names: list[str]):\n        return PipelineUnitGraph().split_pipeline_units(self.units, model_names)\n    \n    \n    def flush_vram_management_device(self, device):\n        for module in self.modules():\n            if isinstance(module, AutoTorchModule):\n                module.offload_device = device\n                module.onload_device = device\n                module.preparing_device = device\n                module.computation_device = device\n                \n    \n    def load_lora(\n        self,\n        module: torch.nn.Module,\n        lora_config: Union[ModelConfig, str] = None,\n        alpha=1,\n        hotload=None,\n        state_dict=None,\n        verbose=1,\n    ):\n        if state_dict is None:\n            if isinstance(lora_config, str):\n                lora = load_state_dict(lora_config, torch_dtype=self.torch_dtype, device=self.device)\n            else:\n                lora_config.download_if_necessary()\n                lora = load_state_dict(lora_config.path, torch_dtype=self.torch_dtype, device=self.device)\n        else:\n            lora = state_dict\n        
lora_loader = self.lora_loader(torch_dtype=self.torch_dtype, device=self.device)\n        lora = lora_loader.convert_state_dict(lora)\n        if hotload is None:\n            hotload = hasattr(module, \"vram_management_enabled\") and getattr(module, \"vram_management_enabled\")\n        if hotload:\n            if not (hasattr(module, \"vram_management_enabled\") and getattr(module, \"vram_management_enabled\")):\n                raise ValueError(\"VRAM Management is not enabled. LoRA hotloading is not supported.\")\n            updated_num = 0\n            for _, module in module.named_modules():\n                if isinstance(module, AutoWrappedLinear):\n                    name = module.name\n                    lora_a_name = f'{name}.lora_A.weight'\n                    lora_b_name = f'{name}.lora_B.weight'\n                    if lora_a_name in lora and lora_b_name in lora:\n                        updated_num += 1\n                        module.lora_A_weights.append(lora[lora_a_name] * alpha)\n                        module.lora_B_weights.append(lora[lora_b_name])\n            if verbose >= 1:\n                print(f\"{updated_num} tensors are patched by LoRA. 
You can use `pipe.clear_lora()` to clear all LoRA layers.\")\n        else:\n            lora_loader.fuse_lora_to_base_model(module, lora, alpha=alpha)\n            \n            \n    def clear_lora(self, verbose=1):\n        cleared_num = 0\n        for name, module in self.named_modules():\n            if isinstance(module, AutoWrappedLinear):\n                if hasattr(module, \"lora_A_weights\"):\n                    if len(module.lora_A_weights) > 0:\n                        cleared_num += 1\n                    module.lora_A_weights.clear()\n                if hasattr(module, \"lora_B_weights\"):\n                    module.lora_B_weights.clear()\n        if verbose >= 1:\n            print(f\"{cleared_num} LoRA layers are cleared.\")\n        \n    \n    def download_and_load_models(self, model_configs: list[ModelConfig] = [], vram_limit: float = None):\n        model_pool = ModelPool()\n        for model_config in model_configs:\n            model_config.download_if_necessary()\n            vram_config = model_config.vram_config()\n            vram_config[\"computation_dtype\"] = vram_config[\"computation_dtype\"] or self.torch_dtype\n            vram_config[\"computation_device\"] = vram_config[\"computation_device\"] or self.device\n            model_pool.auto_load_model(\n                model_config.path,\n                vram_config=vram_config,\n                vram_limit=vram_limit,\n                clear_parameters=model_config.clear_parameters,\n                state_dict=model_config.state_dict,\n            )\n        return model_pool\n    \n    \n    def check_vram_management_state(self):\n        vram_management_enabled = False\n        for module in self.children():\n            if hasattr(module, \"vram_management_enabled\") and getattr(module, \"vram_management_enabled\"):\n                vram_management_enabled = True\n        return vram_management_enabled\n    \n    \n    def cfg_guided_model_fn(self, model_fn, cfg_scale, 
inputs_shared, inputs_posi, inputs_nega, **inputs_others):\n        if inputs_shared.get(\"positive_only_lora\", None) is not None:\n            self.clear_lora(verbose=0)\n            self.load_lora(self.dit, state_dict=inputs_shared[\"positive_only_lora\"], verbose=0)\n        noise_pred_posi = model_fn(**inputs_posi, **inputs_shared, **inputs_others)\n        if cfg_scale != 1.0:\n            if inputs_shared.get(\"positive_only_lora\", None) is not None:\n                self.clear_lora(verbose=0)\n            noise_pred_nega = model_fn(**inputs_nega, **inputs_shared, **inputs_others)\n            if isinstance(noise_pred_posi, tuple):\n                # Separately handling different output types of latents, eg. video and audio latents.\n                noise_pred = tuple(\n                    n_nega + cfg_scale * (n_posi - n_nega)\n                    for n_posi, n_nega in zip(noise_pred_posi, noise_pred_nega)\n                )\n            else:\n                noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)\n        else:\n            noise_pred = noise_pred_posi\n        return noise_pred\n\n\nclass PipelineUnitGraph:\n    def __init__(self):\n        pass\n    \n    def build_edges(self, units: list[PipelineUnit]):\n        # Establish dependencies between units\n        # to search for subsequent related computation units.\n        last_compute_unit_id = {}\n        edges = []\n        for unit_id, unit in enumerate(units):\n            for input_param in unit.fetch_input_params():\n                if input_param in last_compute_unit_id:\n                    edges.append((last_compute_unit_id[input_param], unit_id))\n            for output_param in unit.fetch_output_params():\n                last_compute_unit_id[output_param] = unit_id\n        return edges\n    \n    def build_chains(self, units: list[PipelineUnit]):\n        # Establish updating chains for each variable\n        # to track their computation process.\n   
     params = sum([unit.fetch_input_params() + unit.fetch_output_params() for unit in units], [])\n        params = sorted(list(set(params)))\n        chains = {param: [] for param in params}\n        for unit_id, unit in enumerate(units):\n            for param in unit.fetch_output_params():\n                chains[param].append(unit_id)\n        return chains\n    \n    def search_direct_unit_ids(self, units: list[PipelineUnit], model_names: list[str]):\n        # Search for units that directly participate in the model's computation.\n        related_unit_ids = []\n        for unit_id, unit in enumerate(units):\n            for model_name in model_names:\n                if unit.onload_model_names is not None and model_name in unit.onload_model_names:\n                    related_unit_ids.append(unit_id)\n                    break\n        return related_unit_ids\n    \n    def search_related_unit_ids(self, edges, start_unit_ids, direction=\"target\"):\n        # Search for subsequent related computation units.\n        related_unit_ids = [unit_id for unit_id in start_unit_ids]\n        while True:\n            neighbors = []\n            for source, target in edges:\n                if direction == \"target\" and source in related_unit_ids and target not in related_unit_ids:\n                    neighbors.append(target)\n                elif direction == \"source\" and source not in related_unit_ids and target in related_unit_ids:\n                    neighbors.append(source)\n            neighbors = sorted(list(set(neighbors)))\n            if len(neighbors) == 0:\n                break\n            else:\n                related_unit_ids.extend(neighbors)\n        related_unit_ids = sorted(list(set(related_unit_ids)))\n        return related_unit_ids\n    \n    def search_updating_unit_ids(self, units: list[PipelineUnit], chains, related_unit_ids):\n        # If the input parameters of this subgraph are updated outside the subgraph,\n        # search for the 
units where these updates occur.\n        first_compute_unit_id = {}\n        for unit_id in related_unit_ids:\n            for param in units[unit_id].fetch_input_params():\n                if param not in first_compute_unit_id:\n                    first_compute_unit_id[param] = unit_id\n        updating_unit_ids = []\n        for param in first_compute_unit_id:\n            unit_id = first_compute_unit_id[param]\n            chain = chains[param]\n            if unit_id in chain and chain.index(unit_id) != len(chain) - 1:\n                for unit_id_ in chain[chain.index(unit_id) + 1:]:\n                    if unit_id_ not in related_unit_ids:\n                        updating_unit_ids.append(unit_id_)\n        related_unit_ids.extend(updating_unit_ids)\n        related_unit_ids = sorted(list(set(related_unit_ids)))\n        return related_unit_ids\n    \n    def split_pipeline_units(self, units: list[PipelineUnit], model_names: list[str]):\n        # Split the computation graph,\n        # separating all model-related computations.\n        related_unit_ids = self.search_direct_unit_ids(units, model_names)\n        edges = self.build_edges(units)\n        chains = self.build_chains(units)\n        while True:\n            num_related_unit_ids = len(related_unit_ids)\n            related_unit_ids = self.search_related_unit_ids(edges, related_unit_ids, \"target\")\n            related_unit_ids = self.search_updating_unit_ids(units, chains, related_unit_ids)\n            if len(related_unit_ids) == num_related_unit_ids:\n                break\n            else:\n                num_related_unit_ids = len(related_unit_ids)\n        related_units = [units[i] for i in related_unit_ids]\n        unrelated_units = [units[i] for i in range(len(units)) if i not in related_unit_ids]\n        return related_units, unrelated_units\n\n\nclass PipelineUnitRunner:\n    def __init__(self):\n        pass\n\n    def __call__(self, unit: PipelineUnit, pipe: BasePipeline, 
inputs_shared: dict, inputs_posi: dict, inputs_nega: dict) -> tuple[dict, dict]:\n        if unit.take_over:\n            # Let the pipeline unit take over this function.\n            inputs_shared, inputs_posi, inputs_nega = unit.process(pipe, inputs_shared=inputs_shared, inputs_posi=inputs_posi, inputs_nega=inputs_nega)\n        elif unit.seperate_cfg:\n            # Positive side\n            processor_inputs = {name: inputs_posi.get(name_) for name, name_ in unit.input_params_posi.items()}\n            if unit.input_params is not None:\n                for name in unit.input_params:\n                    processor_inputs[name] = inputs_shared.get(name)\n            processor_outputs = unit.process(pipe, **processor_inputs)\n            inputs_posi.update(processor_outputs)\n            # Negative side\n            if inputs_shared[\"cfg_scale\"] != 1:\n                processor_inputs = {name: inputs_nega.get(name_) for name, name_ in unit.input_params_nega.items()}\n                if unit.input_params is not None:\n                    for name in unit.input_params:\n                        processor_inputs[name] = inputs_shared.get(name)\n                processor_outputs = unit.process(pipe, **processor_inputs)\n                inputs_nega.update(processor_outputs)\n            else:\n                inputs_nega.update(processor_outputs)\n        else:\n            processor_inputs = {name: inputs_shared.get(name) for name in unit.input_params}\n            processor_outputs = unit.process(pipe, **processor_inputs)\n            inputs_shared.update(processor_outputs)\n        return inputs_shared, inputs_posi, inputs_nega\n"
  },
  {
    "path": "diffsynth/diffusion/flow_match.py",
    "content": "import torch, math\nfrom typing_extensions import Literal\n\n\nclass FlowMatchScheduler():\n\n    def __init__(self, template: Literal[\"FLUX.1\", \"Wan\", \"Qwen-Image\", \"FLUX.2\", \"Z-Image\", \"LTX-2\", \"Qwen-Image-Lightning\"] = \"FLUX.1\"):\n        self.set_timesteps_fn = {\n            \"FLUX.1\": FlowMatchScheduler.set_timesteps_flux,\n            \"Wan\": FlowMatchScheduler.set_timesteps_wan,\n            \"Qwen-Image\": FlowMatchScheduler.set_timesteps_qwen_image,\n            \"FLUX.2\": FlowMatchScheduler.set_timesteps_flux2,\n            \"Z-Image\": FlowMatchScheduler.set_timesteps_z_image,\n            \"LTX-2\": FlowMatchScheduler.set_timesteps_ltx2,\n            \"Qwen-Image-Lightning\": FlowMatchScheduler.set_timesteps_qwen_image_lightning,\n        }.get(template, FlowMatchScheduler.set_timesteps_flux)\n        self.num_train_timesteps = 1000\n\n    @staticmethod\n    def set_timesteps_flux(num_inference_steps=100, denoising_strength=1.0, shift=None):\n        sigma_min = 0.003/1.002\n        sigma_max = 1.0\n        shift = 3 if shift is None else shift\n        num_train_timesteps = 1000\n        sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength\n        sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps)\n        sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)\n        timesteps = sigmas * num_train_timesteps\n        return sigmas, timesteps\n    \n    @staticmethod\n    def set_timesteps_wan(num_inference_steps=100, denoising_strength=1.0, shift=None):\n        sigma_min = 0.0\n        sigma_max = 1.0\n        shift = 5 if shift is None else shift\n        num_train_timesteps = 1000\n        sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength\n        sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]\n        sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)\n        timesteps = sigmas * num_train_timesteps\n        return sigmas, 
timesteps\n    \n    @staticmethod\n    def _calculate_shift_qwen_image(image_seq_len, base_seq_len=256, max_seq_len=8192, base_shift=0.5, max_shift=0.9):\n        m = (max_shift - base_shift) / (max_seq_len - base_seq_len)\n        b = base_shift - m * base_seq_len\n        mu = image_seq_len * m + b\n        return mu\n    \n    @staticmethod\n    def set_timesteps_qwen_image(num_inference_steps=100, denoising_strength=1.0, exponential_shift_mu=None, dynamic_shift_len=None):\n        sigma_min = 0.0\n        sigma_max = 1.0\n        num_train_timesteps = 1000\n        shift_terminal = 0.02\n        # Sigmas\n        sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength\n        sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]\n        # Mu\n        if exponential_shift_mu is not None:\n            mu = exponential_shift_mu\n        elif dynamic_shift_len is not None:\n            mu = FlowMatchScheduler._calculate_shift_qwen_image(dynamic_shift_len)\n        else:\n            mu = 0.8\n        sigmas = math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1))\n        # Shift terminal\n        one_minus_z = 1 - sigmas\n        scale_factor = one_minus_z[-1] / (1 - shift_terminal)\n        sigmas = 1 - (one_minus_z / scale_factor)\n        # Timesteps\n        timesteps = sigmas * num_train_timesteps\n        return sigmas, timesteps\n    \n    @staticmethod\n    def set_timesteps_qwen_image_lightning(num_inference_steps=100, denoising_strength=1.0, exponential_shift_mu=None, dynamic_shift_len=None):\n        sigma_min = 0.0\n        sigma_max = 1.0\n        num_train_timesteps = 1000\n        base_shift = math.log(3)\n        max_shift = math.log(3)\n        # Sigmas\n        sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength\n        sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]\n        # Mu\n        if exponential_shift_mu is not None:\n            mu = 
exponential_shift_mu\n        elif dynamic_shift_len is not None:\n            mu = FlowMatchScheduler._calculate_shift_qwen_image(dynamic_shift_len, base_shift=base_shift, max_shift=max_shift)\n        else:\n            mu = 0.8\n        sigmas = math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1))\n        # Timesteps\n        timesteps = sigmas * num_train_timesteps\n        return sigmas, timesteps\n    \n    @staticmethod\n    def compute_empirical_mu(image_seq_len, num_steps):\n        a1, b1 = 8.73809524e-05, 1.89833333\n        a2, b2 = 0.00016927, 0.45666666\n\n        if image_seq_len > 4300:\n            mu = a2 * image_seq_len + b2\n            return float(mu)\n\n        m_200 = a2 * image_seq_len + b2\n        m_10 = a1 * image_seq_len + b1\n\n        a = (m_200 - m_10) / 190.0\n        b = m_200 - 200.0 * a\n        mu = a * num_steps + b\n\n        return float(mu)\n    \n    @staticmethod\n    def set_timesteps_flux2(num_inference_steps=100, denoising_strength=1.0, dynamic_shift_len=None):\n        sigma_min = 1 / num_inference_steps\n        sigma_max = 1.0\n        num_train_timesteps = 1000\n        sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength\n        sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps)\n        if dynamic_shift_len is None:\n            # If you ask me why I set mu=0.8,\n            # I can only say that it yields better training results.\n            mu = 0.8\n        else:\n            mu = FlowMatchScheduler.compute_empirical_mu(dynamic_shift_len, num_inference_steps)\n        sigmas = math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1))\n        timesteps = sigmas * num_train_timesteps\n        return sigmas, timesteps\n\n    @staticmethod\n    def set_timesteps_z_image(num_inference_steps=100, denoising_strength=1.0, shift=None, target_timesteps=None):\n        sigma_min = 0.0\n        sigma_max = 1.0\n        shift = 3 if shift is None else shift\n        num_train_timesteps = 1000\n    
    sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength\n        sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]\n        sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)\n        timesteps = sigmas * num_train_timesteps\n        if target_timesteps is not None:\n            target_timesteps = target_timesteps.to(dtype=timesteps.dtype, device=timesteps.device)\n            for timestep in target_timesteps:\n                timestep_id = torch.argmin((timesteps - timestep).abs())\n                timesteps[timestep_id] = timestep\n        return sigmas, timesteps\n\n    @staticmethod\n    def set_timesteps_ltx2(num_inference_steps=100, denoising_strength=1.0, dynamic_shift_len=None, terminal=0.1, special_case=None):\n        num_train_timesteps = 1000\n        if special_case == \"stage2\":\n            sigmas = torch.Tensor([0.909375, 0.725, 0.421875])\n        elif special_case == \"ditilled_stage1\":\n            sigmas = torch.Tensor([1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875])\n        else:\n            dynamic_shift_len = dynamic_shift_len or 4096\n            sigma_shift = FlowMatchScheduler._calculate_shift_qwen_image(\n                image_seq_len=dynamic_shift_len,\n                base_seq_len=1024,\n                max_seq_len=4096,\n                base_shift=0.95,\n                max_shift=2.05,\n            )\n            sigma_min = 0.0\n            sigma_max = 1.0\n            sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength\n            sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]\n            sigmas = math.exp(sigma_shift) / (math.exp(sigma_shift) + (1 / sigmas - 1))\n            # Shift terminal\n            one_minus_z = 1.0 - sigmas\n            scale_factor = one_minus_z[-1] / (1 - terminal)\n            sigmas = 1.0 - (one_minus_z / scale_factor)\n        timesteps = sigmas * num_train_timesteps\n        return 
sigmas, timesteps\n\n    def set_training_weight(self):\n        steps = 1000\n        x = self.timesteps\n        y = torch.exp(-2 * ((x - steps / 2) / steps) ** 2)\n        y_shifted = y - y.min()\n        bsmntw_weighing = y_shifted * (steps / y_shifted.sum())\n        if len(self.timesteps) != 1000:\n            # This is an empirical formula.\n            bsmntw_weighing = bsmntw_weighing * (len(self.timesteps) / steps)\n            bsmntw_weighing = bsmntw_weighing + bsmntw_weighing[1]\n        self.linear_timesteps_weights = bsmntw_weighing\n        \n    def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, **kwargs):\n        self.sigmas, self.timesteps = self.set_timesteps_fn(\n            num_inference_steps=num_inference_steps,\n            denoising_strength=denoising_strength,\n            **kwargs,\n        )\n        if training:\n            self.set_training_weight()\n            self.training = True\n        else:\n            self.training = False\n\n    def step(self, model_output, timestep, sample, to_final=False, **kwargs):\n        if isinstance(timestep, torch.Tensor):\n            timestep = timestep.cpu()\n        timestep_id = torch.argmin((self.timesteps - timestep).abs())\n        sigma = self.sigmas[timestep_id]\n        if to_final or timestep_id + 1 >= len(self.timesteps):\n            sigma_ = 0\n        else:\n            sigma_ = self.sigmas[timestep_id + 1]\n        prev_sample = sample + model_output * (sigma_ - sigma)\n        return prev_sample\n    \n    def return_to_timestep(self, timestep, sample, sample_stablized):\n        if isinstance(timestep, torch.Tensor):\n            timestep = timestep.cpu()\n        timestep_id = torch.argmin((self.timesteps - timestep).abs())\n        sigma = self.sigmas[timestep_id]\n        model_output = (sample - sample_stablized) / sigma\n        return model_output\n    \n    def add_noise(self, original_samples, noise, timestep):\n        if 
isinstance(timestep, torch.Tensor):\n            timestep = timestep.cpu()\n        timestep_id = torch.argmin((self.timesteps - timestep).abs())\n        sigma = self.sigmas[timestep_id]\n        sample = (1 - sigma) * original_samples + sigma * noise\n        return sample\n    \n    def training_target(self, sample, noise, timestep):\n        target = noise - sample\n        return target\n    \n    def training_weight(self, timestep):\n        timestep_id = torch.argmin((self.timesteps - timestep.to(self.timesteps.device)).abs())\n        weights = self.linear_timesteps_weights[timestep_id]\n        return weights\n"
  },
  {
    "path": "diffsynth/diffusion/logger.py",
    "content": "import os, torch\nfrom accelerate import Accelerator\n\n\nclass ModelLogger:\n    def __init__(self, output_path, remove_prefix_in_ckpt=None, state_dict_converter=lambda x:x):\n        self.output_path = output_path\n        self.remove_prefix_in_ckpt = remove_prefix_in_ckpt\n        self.state_dict_converter = state_dict_converter\n        self.num_steps = 0\n\n\n    def on_step_end(self, accelerator: Accelerator, model: torch.nn.Module, save_steps=None, **kwargs):\n        self.num_steps += 1\n        if save_steps is not None and self.num_steps % save_steps == 0:\n            self.save_model(accelerator, model, f\"step-{self.num_steps}.safetensors\")\n\n\n    def on_epoch_end(self, accelerator: Accelerator, model: torch.nn.Module, epoch_id):\n        accelerator.wait_for_everyone()\n        state_dict = accelerator.get_state_dict(model)\n        if accelerator.is_main_process:\n            state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt)\n            state_dict = self.state_dict_converter(state_dict)\n            os.makedirs(self.output_path, exist_ok=True)\n            path = os.path.join(self.output_path, f\"epoch-{epoch_id}.safetensors\")\n            accelerator.save(state_dict, path, safe_serialization=True)\n\n\n    def on_training_end(self, accelerator: Accelerator, model: torch.nn.Module, save_steps=None):\n        if save_steps is not None and self.num_steps % save_steps != 0:\n            self.save_model(accelerator, model, f\"step-{self.num_steps}.safetensors\")\n\n\n    def save_model(self, accelerator: Accelerator, model: torch.nn.Module, file_name):\n        accelerator.wait_for_everyone()\n        state_dict = accelerator.get_state_dict(model)\n        if accelerator.is_main_process:\n            state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt)\n            state_dict = 
self.state_dict_converter(state_dict)\n            os.makedirs(self.output_path, exist_ok=True)\n            path = os.path.join(self.output_path, file_name)\n            accelerator.save(state_dict, path, safe_serialization=True)\n"
  },
  {
    "path": "diffsynth/diffusion/loss.py",
    "content": "from .base_pipeline import BasePipeline\nimport torch\n\n\ndef FlowMatchSFTLoss(pipe: BasePipeline, **inputs):\n    max_timestep_boundary = int(inputs.get(\"max_timestep_boundary\", 1) * len(pipe.scheduler.timesteps))\n    min_timestep_boundary = int(inputs.get(\"min_timestep_boundary\", 0) * len(pipe.scheduler.timesteps))\n\n    timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,))\n    timestep = pipe.scheduler.timesteps[timestep_id].to(dtype=pipe.torch_dtype, device=pipe.device)\n    \n    noise = torch.randn_like(inputs[\"input_latents\"])\n    inputs[\"latents\"] = pipe.scheduler.add_noise(inputs[\"input_latents\"], noise, timestep)\n    training_target = pipe.scheduler.training_target(inputs[\"input_latents\"], noise, timestep)\n    \n    if \"first_frame_latents\" in inputs:\n        inputs[\"latents\"][:, :, 0:1] = inputs[\"first_frame_latents\"]\n    \n    models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}\n    noise_pred = pipe.model_fn(**models, **inputs, timestep=timestep)\n    \n    if \"first_frame_latents\" in inputs:\n        noise_pred = noise_pred[:, :, 1:]\n        training_target = training_target[:, :, 1:]\n    \n    loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())\n    loss = loss * pipe.scheduler.training_weight(timestep)\n    return loss\n\n\ndef FlowMatchSFTAudioVideoLoss(pipe: BasePipeline, **inputs):\n    max_timestep_boundary = int(inputs.get(\"max_timestep_boundary\", 1) * len(pipe.scheduler.timesteps))\n    min_timestep_boundary = int(inputs.get(\"min_timestep_boundary\", 0) * len(pipe.scheduler.timesteps))\n\n    timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,))\n    timestep = pipe.scheduler.timesteps[timestep_id].to(dtype=pipe.torch_dtype, device=pipe.device)\n    \n    # video\n    noise = torch.randn_like(inputs[\"input_latents\"])\n    inputs[\"video_latents\"] = 
pipe.scheduler.add_noise(inputs[\"input_latents\"], noise, timestep)\n    training_target = pipe.scheduler.training_target(inputs[\"input_latents\"], noise, timestep)\n    \n    # audio\n    if inputs.get(\"audio_input_latents\") is not None:\n        audio_noise = torch.randn_like(inputs[\"audio_input_latents\"])\n        inputs[\"audio_latents\"] = pipe.scheduler.add_noise(inputs[\"audio_input_latents\"], audio_noise, timestep)\n        training_target_audio = pipe.scheduler.training_target(inputs[\"audio_input_latents\"], audio_noise, timestep)\n\n    models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}\n    noise_pred, noise_pred_audio = pipe.model_fn(**models, **inputs, timestep=timestep)\n\n    loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())\n    loss = loss * pipe.scheduler.training_weight(timestep)\n    if inputs.get(\"audio_input_latents\") is not None:\n        loss_audio = torch.nn.functional.mse_loss(noise_pred_audio.float(), training_target_audio.float())\n        loss_audio = loss_audio * pipe.scheduler.training_weight(timestep)\n        loss = loss + loss_audio\n    return loss\n\n\ndef DirectDistillLoss(pipe: BasePipeline, **inputs):\n    pipe.scheduler.set_timesteps(inputs[\"num_inference_steps\"])\n    pipe.scheduler.training = True\n    models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}\n    for progress_id, timestep in enumerate(pipe.scheduler.timesteps):\n        timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)\n        noise_pred = pipe.model_fn(**models, **inputs, timestep=timestep, progress_id=progress_id)\n        inputs[\"latents\"] = pipe.step(pipe.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs)\n    loss = torch.nn.functional.mse_loss(inputs[\"latents\"].float(), inputs[\"input_latents\"].float())\n    return loss\n\n\nclass TrajectoryImitationLoss(torch.nn.Module):\n    def __init__(self):\n        
super().__init__()\n        self.initialized = False\n    \n    def initialize(self, device):\n        import lpips # TODO: remove it\n        self.loss_fn = lpips.LPIPS(net='alex').to(device)\n        self.initialized = True\n\n    def fetch_trajectory(self, pipe: BasePipeline, timesteps_student, inputs_shared, inputs_posi, inputs_nega, num_inference_steps, cfg_scale):\n        trajectory = [inputs_shared[\"latents\"].clone()]\n\n        pipe.scheduler.set_timesteps(num_inference_steps, target_timesteps=timesteps_student)\n        models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}\n        for progress_id, timestep in enumerate(pipe.scheduler.timesteps):\n            timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)\n            noise_pred = pipe.cfg_guided_model_fn(\n                pipe.model_fn, cfg_scale,\n                inputs_shared, inputs_posi, inputs_nega,\n                **models, timestep=timestep, progress_id=progress_id\n            )\n            inputs_shared[\"latents\"] = pipe.step(pipe.scheduler, progress_id=progress_id, noise_pred=noise_pred.detach(), **inputs_shared)\n\n            trajectory.append(inputs_shared[\"latents\"].clone())\n        return pipe.scheduler.timesteps, trajectory\n    \n    def align_trajectory(self, pipe: BasePipeline, timesteps_teacher, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, num_inference_steps, cfg_scale):\n        loss = 0\n        pipe.scheduler.set_timesteps(num_inference_steps, training=True)\n        models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}\n        for progress_id, timestep in enumerate(pipe.scheduler.timesteps):\n            timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)\n\n            progress_id_teacher = torch.argmin((timesteps_teacher - timestep).abs())\n            inputs_shared[\"latents\"] = trajectory_teacher[progress_id_teacher]\n\n            noise_pred = 
pipe.cfg_guided_model_fn(\n                pipe.model_fn, cfg_scale,\n                inputs_shared, inputs_posi, inputs_nega,\n                **models, timestep=timestep, progress_id=progress_id\n            )\n\n            sigma = pipe.scheduler.sigmas[progress_id]\n            sigma_ = 0 if progress_id + 1 >= len(pipe.scheduler.timesteps) else pipe.scheduler.sigmas[progress_id + 1]\n            if progress_id + 1 >= len(pipe.scheduler.timesteps):\n                latents_ = trajectory_teacher[-1]\n            else:\n                progress_id_teacher = torch.argmin((timesteps_teacher - pipe.scheduler.timesteps[progress_id + 1]).abs())\n                latents_ = trajectory_teacher[progress_id_teacher]\n            \n            denom = sigma_ - sigma\n            denom = torch.sign(denom) * torch.clamp(denom.abs(), min=1e-6)\n            target = (latents_ - inputs_shared[\"latents\"]) / denom\n            loss = loss + torch.nn.functional.mse_loss(noise_pred.float(), target.float()) * pipe.scheduler.training_weight(timestep)\n        return loss\n    \n    def compute_regularization(self, pipe: BasePipeline, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, num_inference_steps, cfg_scale):\n        inputs_shared[\"latents\"] = trajectory_teacher[0]\n        pipe.scheduler.set_timesteps(num_inference_steps)\n        models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}\n        for progress_id, timestep in enumerate(pipe.scheduler.timesteps):\n            timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)\n            noise_pred = pipe.cfg_guided_model_fn(\n                pipe.model_fn, cfg_scale,\n                inputs_shared, inputs_posi, inputs_nega,\n                **models, timestep=timestep, progress_id=progress_id\n            )\n            inputs_shared[\"latents\"] = pipe.step(pipe.scheduler, progress_id=progress_id, noise_pred=noise_pred.detach(), **inputs_shared)\n\n        image_pred 
= pipe.vae_decoder(inputs_shared[\"latents\"])\n        image_real = pipe.vae_decoder(trajectory_teacher[-1])\n        loss = self.loss_fn(image_pred.float(), image_real.float())\n        return loss\n\n    def forward(self, pipe: BasePipeline, inputs_shared, inputs_posi, inputs_nega):\n        if not self.initialized:\n            self.initialize(pipe.device)\n        with torch.no_grad():\n            pipe.scheduler.set_timesteps(8)\n            timesteps_teacher, trajectory_teacher = self.fetch_trajectory(inputs_shared[\"teacher\"], pipe.scheduler.timesteps, inputs_shared, inputs_posi, inputs_nega, 50, 2)\n            timesteps_teacher = timesteps_teacher.to(dtype=pipe.torch_dtype, device=pipe.device)\n        loss_1 = self.align_trajectory(pipe, timesteps_teacher, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, 8, 1)\n        loss_2 = self.compute_regularization(pipe, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, 8, 1)\n        loss = loss_1 + loss_2\n        return loss\n"
  },
  {
    "path": "diffsynth/diffusion/parsers.py",
    "content": "import argparse\n\n\ndef add_dataset_base_config(parser: argparse.ArgumentParser):\n    parser.add_argument(\"--dataset_base_path\", type=str, default=\"\", required=True, help=\"Base path of the dataset.\")\n    parser.add_argument(\"--dataset_metadata_path\", type=str, default=None, help=\"Path to the metadata file of the dataset.\")\n    parser.add_argument(\"--dataset_repeat\", type=int, default=1, help=\"Number of times to repeat the dataset per epoch.\")\n    parser.add_argument(\"--dataset_num_workers\", type=int, default=0, help=\"Number of workers for data loading.\")\n    parser.add_argument(\"--data_file_keys\", type=str, default=\"image,video\", help=\"Data file keys in the metadata. Comma-separated.\")\n    return parser\n\ndef add_image_size_config(parser: argparse.ArgumentParser):\n    parser.add_argument(\"--height\", type=int, default=None, help=\"Height of images. Leave `height` and `width` empty to enable dynamic resolution.\")\n    parser.add_argument(\"--width\", type=int, default=None, help=\"Width of images. Leave `height` and `width` empty to enable dynamic resolution.\")\n    parser.add_argument(\"--max_pixels\", type=int, default=1024*1024, help=\"Maximum number of pixels per frame, used for dynamic resolution.\")\n    return parser\n\ndef add_video_size_config(parser: argparse.ArgumentParser):\n    parser.add_argument(\"--height\", type=int, default=None, help=\"Height of images. Leave `height` and `width` empty to enable dynamic resolution.\")\n    parser.add_argument(\"--width\", type=int, default=None, help=\"Width of images. Leave `height` and `width` empty to enable dynamic resolution.\")\n    parser.add_argument(\"--max_pixels\", type=int, default=1024*1024, help=\"Maximum number of pixels per frame, used for dynamic resolution.\")\n    parser.add_argument(\"--num_frames\", type=int, default=81, help=\"Number of frames per video. 
Frames are sampled from the video prefix.\")\n    return parser\n\ndef add_model_config(parser: argparse.ArgumentParser):\n    parser.add_argument(\"--model_paths\", type=str, default=None, help=\"Paths to load models. In JSON format.\")\n    parser.add_argument(\"--model_id_with_origin_paths\", type=str, default=None, help=\"Model ID with origin paths, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors. Comma-separated.\")\n    parser.add_argument(\"--extra_inputs\", default=None, help=\"Additional model inputs, comma-separated.\")\n    parser.add_argument(\"--fp8_models\", default=None, help=\"Models with FP8 precision, comma-separated.\")\n    parser.add_argument(\"--offload_models\", default=None, help=\"Models with offload, comma-separated. Only used in splited training.\")\n    return parser\n\ndef add_training_config(parser: argparse.ArgumentParser):\n    parser.add_argument(\"--learning_rate\", type=float, default=1e-4, help=\"Learning rate.\")\n    parser.add_argument(\"--num_epochs\", type=int, default=1, help=\"Number of epochs.\")\n    parser.add_argument(\"--trainable_models\", type=str, default=None, help=\"Models to train, e.g., dit, vae, text_encoder.\")\n    parser.add_argument(\"--find_unused_parameters\", default=False, action=\"store_true\", help=\"Whether to find unused parameters in DDP.\")\n    parser.add_argument(\"--weight_decay\", type=float, default=0.01, help=\"Weight decay.\")\n    parser.add_argument(\"--task\", type=str, default=\"sft\", required=False, help=\"Task type.\")\n    return parser\n\ndef add_output_config(parser: argparse.ArgumentParser):\n    parser.add_argument(\"--output_path\", type=str, default=\"./models\", help=\"Output save path.\")\n    parser.add_argument(\"--remove_prefix_in_ckpt\", type=str, default=\"pipe.dit.\", help=\"Remove prefix in ckpt.\")\n    parser.add_argument(\"--save_steps\", type=int, default=None, help=\"Number of checkpoint saving invervals. 
If None, checkpoints will be saved every epoch.\")\n    return parser\n\ndef add_lora_config(parser: argparse.ArgumentParser):\n    parser.add_argument(\"--lora_base_model\", type=str, default=None, help=\"Which model LoRA is added to.\")\n    parser.add_argument(\"--lora_target_modules\", type=str, default=\"q,k,v,o,ffn.0,ffn.2\", help=\"Which layers LoRA is added to.\")\n    parser.add_argument(\"--lora_rank\", type=int, default=32, help=\"Rank of LoRA.\")\n    parser.add_argument(\"--lora_checkpoint\", type=str, default=None, help=\"Path to the LoRA checkpoint. If provided, LoRA will be loaded from this checkpoint.\")\n    parser.add_argument(\"--preset_lora_path\", type=str, default=None, help=\"Path to the preset LoRA checkpoint. If provided, this LoRA will be fused to the base model.\")\n    parser.add_argument(\"--preset_lora_model\", type=str, default=None, help=\"Which model the preset LoRA is fused to.\")\n    return parser\n\ndef add_gradient_config(parser: argparse.ArgumentParser):\n    parser.add_argument(\"--use_gradient_checkpointing\", default=False, action=\"store_true\", help=\"Whether to use gradient checkpointing.\")\n    parser.add_argument(\"--use_gradient_checkpointing_offload\", default=False, action=\"store_true\", help=\"Whether to offload gradient checkpointing to CPU memory.\")\n    parser.add_argument(\"--gradient_accumulation_steps\", type=int, default=1, help=\"Gradient accumulation steps.\")\n    return parser\n\ndef add_general_config(parser: argparse.ArgumentParser):\n    parser = add_dataset_base_config(parser)\n    parser = add_model_config(parser)\n    parser = add_training_config(parser)\n    parser = add_output_config(parser)\n    parser = add_lora_config(parser)\n    parser = add_gradient_config(parser)\n    return parser\n"
  },
  {
    "path": "diffsynth/diffusion/runner.py",
    "content": "import os, torch\nfrom tqdm import tqdm\nfrom accelerate import Accelerator\nfrom .training_module import DiffusionTrainingModule\nfrom .logger import ModelLogger\n\n\ndef launch_training_task(\n    accelerator: Accelerator,\n    dataset: torch.utils.data.Dataset,\n    model: DiffusionTrainingModule,\n    model_logger: ModelLogger,\n    learning_rate: float = 1e-5,\n    weight_decay: float = 1e-2,\n    num_workers: int = 1,\n    save_steps: int = None,\n    num_epochs: int = 1,\n    args = None,\n):\n    if args is not None:\n        learning_rate = args.learning_rate\n        weight_decay = args.weight_decay\n        num_workers = args.dataset_num_workers\n        save_steps = args.save_steps\n        num_epochs = args.num_epochs\n    \n    optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate, weight_decay=weight_decay)\n    scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)\n    dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers)\n    model.to(device=accelerator.device)\n    model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)\n    initialize_deepspeed_gradient_checkpointing(accelerator)\n    for epoch_id in range(num_epochs):\n        for data in tqdm(dataloader):\n            with accelerator.accumulate(model):\n                optimizer.zero_grad()\n                if dataset.load_from_cache:\n                    loss = model({}, inputs=data)\n                else:\n                    loss = model(data)\n                accelerator.backward(loss)\n                optimizer.step()\n                model_logger.on_step_end(accelerator, model, save_steps, loss=loss)\n                scheduler.step()\n        if save_steps is None:\n            model_logger.on_epoch_end(accelerator, model, epoch_id)\n    model_logger.on_training_end(accelerator, model, save_steps)\n\n\ndef launch_data_process_task(\n    
accelerator: Accelerator,\n    dataset: torch.utils.data.Dataset,\n    model: DiffusionTrainingModule,\n    model_logger: ModelLogger,\n    num_workers: int = 8,\n    args = None,\n):\n    if args is not None:\n        num_workers = args.dataset_num_workers\n        \n    dataloader = torch.utils.data.DataLoader(dataset, shuffle=False, collate_fn=lambda x: x[0], num_workers=num_workers)\n    model.to(device=accelerator.device)\n    model, dataloader = accelerator.prepare(model, dataloader)\n    \n    for data_id, data in enumerate(tqdm(dataloader)):\n        with accelerator.accumulate(model):\n            with torch.no_grad():\n                folder = os.path.join(model_logger.output_path, str(accelerator.process_index))\n                os.makedirs(folder, exist_ok=True)\n                save_path = os.path.join(model_logger.output_path, str(accelerator.process_index), f\"{data_id}.pth\")\n                data = model(data)\n                torch.save(data, save_path)\n\n\ndef initialize_deepspeed_gradient_checkpointing(accelerator: Accelerator):\n    if getattr(accelerator.state, \"deepspeed_plugin\", None) is not None:\n        ds_config = accelerator.state.deepspeed_plugin.deepspeed_config\n        if \"activation_checkpointing\" in ds_config:\n            import deepspeed\n            act_config = ds_config[\"activation_checkpointing\"]\n            deepspeed.checkpointing.configure(\n                mpu_=None, \n                partition_activations=act_config.get(\"partition_activations\", False),\n                checkpoint_in_cpu=act_config.get(\"cpu_checkpointing\", False),\n                contiguous_checkpointing=act_config.get(\"contiguous_memory_optimization\", False)\n            )\n        else:\n            print(\"Do not find activation_checkpointing config in deepspeed config, skip initializing deepspeed gradient checkpointing.\")\n"
  },
  {
    "path": "diffsynth/diffusion/training_module.py",
    "content": "import torch, json, os, inspect\nfrom ..core import ModelConfig, load_state_dict\nfrom ..utils.controlnet import ControlNetInput\nfrom .base_pipeline import PipelineUnit\nfrom peft import LoraConfig, inject_adapter_in_model\n\n\nclass GeneralUnit_RemoveCache(PipelineUnit):\n    def __init__(self, required_params=tuple(), force_remove_params_shared=tuple(), force_remove_params_posi=tuple(), force_remove_params_nega=tuple()):\n        super().__init__(take_over=True)\n        self.required_params = required_params\n        self.force_remove_params_shared = force_remove_params_shared\n        self.force_remove_params_posi = force_remove_params_posi\n        self.force_remove_params_nega = force_remove_params_nega\n\n    def process_params(self, inputs, required_params, force_remove_params):\n        inputs_ = {}\n        for name, param in inputs.items():\n            if name in required_params and name not in force_remove_params:\n                inputs_[name] = param\n        return inputs_\n\n    def process(self, pipe, inputs_shared, inputs_posi, inputs_nega):\n        inputs_shared = self.process_params(inputs_shared, self.required_params, self.force_remove_params_shared)\n        inputs_posi = self.process_params(inputs_posi, self.required_params, self.force_remove_params_posi)\n        inputs_nega = self.process_params(inputs_nega, self.required_params, self.force_remove_params_nega)\n        return inputs_shared, inputs_posi, inputs_nega\n\n\nclass DiffusionTrainingModule(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        \n        \n    def to(self, *args, **kwargs):\n        for name, model in self.named_children():\n            model.to(*args, **kwargs)\n        return self\n        \n        \n    def trainable_modules(self):\n        trainable_modules = filter(lambda p: p.requires_grad, self.parameters())\n        return trainable_modules\n    \n    \n    def trainable_param_names(self):\n        
trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.named_parameters()))\n        trainable_param_names = set([named_param[0] for named_param in trainable_param_names])\n        return trainable_param_names\n    \n    \n    def add_lora_to_model(self, model, target_modules, lora_rank, lora_alpha=None, upcast_dtype=None):\n        if lora_alpha is None:\n            lora_alpha = lora_rank\n        if isinstance(target_modules, list) and len(target_modules) == 1:\n            target_modules = target_modules[0]\n        lora_config = LoraConfig(r=lora_rank, lora_alpha=lora_alpha, target_modules=target_modules)\n        model = inject_adapter_in_model(lora_config, model)\n        if upcast_dtype is not None:\n            for param in model.parameters():\n                if param.requires_grad:\n                    param.data = param.to(upcast_dtype)\n        return model\n\n\n    def mapping_lora_state_dict(self, state_dict):\n        new_state_dict = {}\n        for key, value in state_dict.items():\n            if \"lora_A.weight\" in key or \"lora_B.weight\" in key:\n                new_key = key.replace(\"lora_A.weight\", \"lora_A.default.weight\").replace(\"lora_B.weight\", \"lora_B.default.weight\")\n                new_state_dict[new_key] = value\n            elif \"lora_A.default.weight\" in key or \"lora_B.default.weight\" in key:\n                new_state_dict[key] = value\n        return new_state_dict\n\n\n    def export_trainable_state_dict(self, state_dict, remove_prefix=None):\n        trainable_param_names = self.trainable_param_names()\n        state_dict = {name: param for name, param in state_dict.items() if name in trainable_param_names}\n        if remove_prefix is not None:\n            state_dict_ = {}\n            for name, param in state_dict.items():\n                if name.startswith(remove_prefix):\n                    name = name[len(remove_prefix):]\n                state_dict_[name] = param\n       
     state_dict = state_dict_\n        return state_dict\n    \n    \n    def transfer_data_to_device(self, data, device, torch_float_dtype=None):\n        if data is None:\n            return data\n        elif isinstance(data, torch.Tensor):\n            data = data.to(device)\n            if torch_float_dtype is not None and data.dtype in [torch.float, torch.float16, torch.bfloat16]:\n                data = data.to(torch_float_dtype)\n            return data\n        elif isinstance(data, tuple):\n            data = tuple(self.transfer_data_to_device(x, device, torch_float_dtype) for x in data)\n            return data\n        elif isinstance(data, list):\n            data = list(self.transfer_data_to_device(x, device, torch_float_dtype) for x in data)\n            return data\n        elif isinstance(data, dict):\n            data = {i: self.transfer_data_to_device(data[i], device, torch_float_dtype) for i in data}\n            return data\n        else:\n            return data\n    \n    def parse_vram_config(self, fp8=False, offload=False, device=\"cpu\"):\n        if fp8:\n            return {\n                \"offload_dtype\": torch.float8_e4m3fn,\n                \"offload_device\": device,\n                \"onload_dtype\": torch.float8_e4m3fn,\n                \"onload_device\": device,\n                \"preparing_dtype\": torch.float8_e4m3fn,\n                \"preparing_device\": device,\n                \"computation_dtype\": torch.bfloat16,\n                \"computation_device\": device,\n            }\n        elif offload:\n            return {\n                \"offload_dtype\": \"disk\",\n                \"offload_device\": \"disk\",\n                \"onload_dtype\": \"disk\",\n                \"onload_device\": \"disk\",\n                \"preparing_dtype\": torch.bfloat16,\n                \"preparing_device\": device,\n                \"computation_dtype\": torch.bfloat16,\n                \"computation_device\": device,\n                
\"clear_parameters\": True,\n            }\n        else:\n            return {}\n    \n    def parse_model_configs(self, model_paths, model_id_with_origin_paths, fp8_models=None, offload_models=None, device=\"cpu\"):\n        fp8_models = [] if fp8_models is None else fp8_models.split(\",\")\n        offload_models = [] if offload_models is None else offload_models.split(\",\")\n        model_configs = []\n        if model_paths is not None:\n            model_paths = json.loads(model_paths)\n            for path in model_paths:\n                vram_config = self.parse_vram_config(\n                    fp8=path in fp8_models,\n                    offload=path in offload_models,\n                    device=device\n                )\n                model_configs.append(ModelConfig(path=path, **vram_config))\n        if model_id_with_origin_paths is not None:\n            model_id_with_origin_paths = model_id_with_origin_paths.split(\",\")\n            for model_id_with_origin_path in model_id_with_origin_paths:\n                vram_config = self.parse_vram_config(\n                    fp8=model_id_with_origin_path in fp8_models,\n                    offload=model_id_with_origin_path in offload_models,\n                    device=device\n                )\n                config = self.parse_path_or_model_id(model_id_with_origin_path)\n                model_configs.append(ModelConfig(model_id=config.model_id, origin_file_pattern=config.origin_file_pattern, **vram_config))\n        return model_configs\n    \n\n    def parse_path_or_model_id(self, model_id_with_origin_path, default_value=None):\n        if model_id_with_origin_path is None:\n            return default_value\n        elif os.path.exists(model_id_with_origin_path):\n            return ModelConfig(path=model_id_with_origin_path)\n        else:\n            if \":\" not in model_id_with_origin_path:\n                raise ValueError(f\"Failed to parse model config: {model_id_with_origin_path}. 
This is neither a valid path nor in the format of `model_id/origin_file_pattern`.\")\n            split_id = model_id_with_origin_path.rfind(\":\")\n            model_id = model_id_with_origin_path[:split_id]\n            origin_file_pattern = model_id_with_origin_path[split_id + 1:]\n            return ModelConfig(model_id=model_id, origin_file_pattern=origin_file_pattern)\n\n\n    def auto_detect_lora_target_modules(\n        self,\n        model: torch.nn.Module,\n        search_for_linear=False,\n        linear_detector=lambda x: min(x.weight.shape) >= 512,\n        block_list_detector=lambda x: isinstance(x, torch.nn.ModuleList) and len(x) > 1,\n        name_prefix=\"\",\n    ):\n        lora_target_modules = []\n        if search_for_linear:\n            for name, module in model.named_modules():\n                module_name = name_prefix + [\"\", \".\"][name_prefix != \"\"] + name\n                if isinstance(module, torch.nn.Linear) and linear_detector(module):\n                    lora_target_modules.append(module_name)\n        else:\n            for name, module in model.named_children():\n                module_name = name_prefix + [\"\", \".\"][name_prefix != \"\"] + name\n                lora_target_modules += self.auto_detect_lora_target_modules(\n                    module,\n                    search_for_linear=block_list_detector(module),\n                    linear_detector=linear_detector,\n                    block_list_detector=block_list_detector,\n                    name_prefix=module_name,\n                )\n        return lora_target_modules\n    \n\n    def parse_lora_target_modules(self, model, lora_target_modules):\n        if lora_target_modules == \"\":\n            print(\"No LoRA target modules specified. 
The framework will automatically search for them.\")\n            lora_target_modules = self.auto_detect_lora_target_modules(model)\n            print(f\"LoRA will be patched at {lora_target_modules}.\")\n        else:\n            lora_target_modules = lora_target_modules.split(\",\")\n        return lora_target_modules\n\n\n    def switch_pipe_to_training_mode(\n        self,\n        pipe,\n        trainable_models=None,\n        lora_base_model=None, lora_target_modules=\"\", lora_rank=32, lora_checkpoint=None,\n        preset_lora_path=None, preset_lora_model=None,\n        task=\"sft\",\n    ):\n        # Scheduler\n        pipe.scheduler.set_timesteps(1000, training=True)\n        \n        # Freeze untrainable models\n        pipe.freeze_except([] if trainable_models is None else trainable_models.split(\",\"))\n        \n        # Preset LoRA\n        if preset_lora_path is not None:\n            pipe.load_lora(getattr(pipe, preset_lora_model), preset_lora_path)\n        \n        # FP8\n        # FP8 relies on a model-specific memory management scheme.\n        # It is delegated to the subclass.\n        \n        # Add LoRA to the base models\n        if lora_base_model is not None and not task.endswith(\":data_process\"):\n            if (not hasattr(pipe, lora_base_model)) or getattr(pipe, lora_base_model) is None:\n                print(f\"No {lora_base_model} models in the pipeline. We cannot patch LoRA on the model. 
If this occurs during the data processing stage, it is normal.\")\n                return\n            model = self.add_lora_to_model(\n                getattr(pipe, lora_base_model),\n                target_modules=self.parse_lora_target_modules(getattr(pipe, lora_base_model), lora_target_modules),\n                lora_rank=lora_rank,\n                upcast_dtype=pipe.torch_dtype,\n            )\n            if lora_checkpoint is not None:\n                state_dict = load_state_dict(lora_checkpoint)\n                state_dict = self.mapping_lora_state_dict(state_dict)\n                load_result = model.load_state_dict(state_dict, strict=False)\n                print(f\"LoRA checkpoint loaded: {lora_checkpoint}, total {len(state_dict)} keys\")\n                if len(load_result[1]) > 0:\n                    print(f\"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}\")\n            setattr(pipe, lora_base_model, model)\n\n\n    def split_pipeline_units(\n        self, task, pipe,\n        trainable_models=None, lora_base_model=None,\n        # TODO: set `remove_unnecessary_params` to `True` by default\n        remove_unnecessary_params=False,\n        # TODO: move `loss_required_params` to `loss.py`\n        loss_required_params=(\"input_latents\", \"max_timestep_boundary\", \"min_timestep_boundary\", \"first_frame_latents\", \"video_latents\", \"audio_input_latents\", \"num_inference_steps\"),\n        force_remove_params_shared=tuple(),\n        force_remove_params_posi=tuple(),\n        force_remove_params_nega=tuple(),\n    ):\n        models_require_backward = []\n        if trainable_models is not None:\n            models_require_backward += trainable_models.split(\",\")\n        if lora_base_model is not None:\n            models_require_backward += [lora_base_model]\n        if task.endswith(\":data_process\"):\n            other_units, pipe.units = pipe.split_pipeline_units(models_require_backward)\n            if 
remove_unnecessary_params:\n                required_params = list(loss_required_params) + [i for i in inspect.signature(self.pipe.model_fn).parameters]\n                for unit in other_units:\n                    required_params.extend(unit.fetch_input_params())\n                required_params = sorted(list(set(required_params)))\n                pipe.units.append(GeneralUnit_RemoveCache(required_params, force_remove_params_shared, force_remove_params_posi, force_remove_params_nega))\n        elif task.endswith(\":train\"):\n            pipe.units, _ = pipe.split_pipeline_units(models_require_backward)\n        return pipe\n    \n    def parse_extra_inputs(self, data, extra_inputs, inputs_shared):\n        controlnet_keys_map = (\n            (\"blockwise_controlnet_\", \"blockwise_controlnet_inputs\",),\n            (\"controlnet_\", \"controlnet_inputs\"),\n        )\n        controlnet_inputs = {}\n        for extra_input in extra_inputs:\n            for prefix, name in controlnet_keys_map:\n                if extra_input.startswith(prefix):\n                    if name not in controlnet_inputs:\n                        controlnet_inputs[name] = {}\n                    controlnet_inputs[name][extra_input.replace(prefix, \"\")] = data[extra_input]\n                    break\n            else:\n                inputs_shared[extra_input] = data[extra_input]\n        for name, params in controlnet_inputs.items():\n            inputs_shared[name] = [ControlNetInput(**params)]\n        return inputs_shared\n"
  },
  {
    "path": "diffsynth/models/anima_dit.py",
    "content": "# original code from: comfy/ldm/cosmos/predict2.py\n\nimport torch\nfrom torch import nn\nfrom einops import rearrange, repeat\nfrom einops.layers.torch import Rearrange\nimport logging\nfrom typing import Callable, Optional, Tuple, List\nimport math\nfrom torchvision import transforms\nfrom ..core.attention import attention_forward\nfrom ..core.gradient import gradient_checkpoint_forward\n\n\nclass VideoPositionEmb(nn.Module):\n    def forward(self, x_B_T_H_W_C: torch.Tensor, fps=Optional[torch.Tensor], device=None, dtype=None) -> torch.Tensor:\n        \"\"\"\n        It delegates the embedding generation to generate_embeddings function.\n        \"\"\"\n        B_T_H_W_C = x_B_T_H_W_C.shape\n        embeddings = self.generate_embeddings(B_T_H_W_C, fps=fps, device=device, dtype=dtype)\n\n        return embeddings\n\n    def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor], device=None):\n        raise NotImplementedError\n\n\ndef normalize(x: torch.Tensor, dim: Optional[List[int]] = None, eps: float = 0) -> torch.Tensor:\n    \"\"\"\n    Normalizes the input tensor along specified dimensions such that the average square norm of elements is adjusted.\n\n    Args:\n        x (torch.Tensor): The input tensor to normalize.\n        dim (list, optional): The dimensions over which to normalize. 
If None, normalizes over all dimensions except the first.\n        eps (float, optional): A small constant to ensure numerical stability during division.\n\n    Returns:\n        torch.Tensor: The normalized tensor.\n    \"\"\"\n    if dim is None:\n        dim = list(range(1, x.ndim))\n    norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32)\n    norm = torch.add(eps, norm, alpha=math.sqrt(norm.numel() / x.numel()))\n    return x / norm.to(x.dtype)\n\n\nclass LearnablePosEmbAxis(VideoPositionEmb):\n    def __init__(\n        self,\n        *,  # enforce keyword arguments\n        interpolation: str,\n        model_channels: int,\n        len_h: int,\n        len_w: int,\n        len_t: int,\n        device=None,\n        dtype=None,\n        **kwargs,\n    ):\n        \"\"\"\n        Args:\n            interpolation (str): we curretly only support \"crop\", ideally when we need extrapolation capacity, we should adjust frequency or other more advanced methods. they are not implemented yet.\n        \"\"\"\n        del kwargs  # unused\n        super().__init__()\n        self.interpolation = interpolation\n        assert self.interpolation in [\"crop\"], f\"Unknown interpolation method {self.interpolation}\"\n\n        self.pos_emb_h = nn.Parameter(torch.empty(len_h, model_channels, device=device, dtype=dtype))\n        self.pos_emb_w = nn.Parameter(torch.empty(len_w, model_channels, device=device, dtype=dtype))\n        self.pos_emb_t = nn.Parameter(torch.empty(len_t, model_channels, device=device, dtype=dtype))\n\n    def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor], device=None, dtype=None) -> torch.Tensor:\n        B, T, H, W, _ = B_T_H_W_C\n        if self.interpolation == \"crop\":\n            emb_h_H = self.pos_emb_h[:H].to(device=device, dtype=dtype)\n            emb_w_W = self.pos_emb_w[:W].to(device=device, dtype=dtype)\n            emb_t_T = self.pos_emb_t[:T].to(device=device, dtype=dtype)\n      
      emb = (\n                repeat(emb_t_T, \"t d-> b t h w d\", b=B, h=H, w=W)\n                + repeat(emb_h_H, \"h d-> b t h w d\", b=B, t=T, w=W)\n                + repeat(emb_w_W, \"w d-> b t h w d\", b=B, t=T, h=H)\n            )\n            assert list(emb.shape)[:4] == [B, T, H, W], f\"bad shape: {list(emb.shape)[:4]} != {B, T, H, W}\"\n        else:\n            raise ValueError(f\"Unknown interpolation method {self.interpolation}\")\n\n        return normalize(emb, dim=-1, eps=1e-6)\n\n\nclass VideoRopePosition3DEmb(VideoPositionEmb):\n    def __init__(\n        self,\n        *,  # enforce keyword arguments\n        head_dim: int,\n        len_h: int,\n        len_w: int,\n        len_t: int,\n        base_fps: int = 24,\n        h_extrapolation_ratio: float = 1.0,\n        w_extrapolation_ratio: float = 1.0,\n        t_extrapolation_ratio: float = 1.0,\n        enable_fps_modulation: bool = True,\n        device=None,\n        **kwargs,  # used for compatibility with other positional embeddings; unused in this class\n    ):\n        del kwargs\n        super().__init__()\n        self.base_fps = base_fps\n        self.max_h = len_h\n        self.max_w = len_w\n        self.enable_fps_modulation = enable_fps_modulation\n\n        dim = head_dim\n        dim_h = dim // 6 * 2\n        dim_w = dim_h\n        dim_t = dim - 2 * dim_h\n        assert dim == dim_h + dim_w + dim_t, f\"bad dim: {dim} != {dim_h} + {dim_w} + {dim_t}\"\n        self.register_buffer(\n            \"dim_spatial_range\",\n            torch.arange(0, dim_h, 2, device=device)[: (dim_h // 2)].float() / dim_h,\n            persistent=False,\n        )\n        self.register_buffer(\n            \"dim_temporal_range\",\n            torch.arange(0, dim_t, 2, device=device)[: (dim_t // 2)].float() / dim_t,\n            persistent=False,\n        )\n\n        self.h_ntk_factor = h_extrapolation_ratio ** (dim_h / (dim_h - 2))\n        self.w_ntk_factor = w_extrapolation_ratio ** (dim_w / 
(dim_w - 2))\n        self.t_ntk_factor = t_extrapolation_ratio ** (dim_t / (dim_t - 2))\n\n    def generate_embeddings(\n        self,\n        B_T_H_W_C: torch.Size,\n        fps: Optional[torch.Tensor] = None,\n        h_ntk_factor: Optional[float] = None,\n        w_ntk_factor: Optional[float] = None,\n        t_ntk_factor: Optional[float] = None,\n        device=None,\n        dtype=None,\n    ):\n        \"\"\"\n        Generate embeddings for the given input size.\n\n        Args:\n            B_T_H_W_C (torch.Size): Input tensor size (Batch, Time, Height, Width, Channels).\n            fps (Optional[torch.Tensor], optional): Frames per second. Defaults to None.\n            h_ntk_factor (Optional[float], optional): Height NTK factor. If None, uses self.h_ntk_factor.\n            w_ntk_factor (Optional[float], optional): Width NTK factor. If None, uses self.w_ntk_factor.\n            t_ntk_factor (Optional[float], optional): Time NTK factor. If None, uses self.t_ntk_factor.\n\n        Returns:\n            Not specified in the original code snippet.\n        \"\"\"\n        h_ntk_factor = h_ntk_factor if h_ntk_factor is not None else self.h_ntk_factor\n        w_ntk_factor = w_ntk_factor if w_ntk_factor is not None else self.w_ntk_factor\n        t_ntk_factor = t_ntk_factor if t_ntk_factor is not None else self.t_ntk_factor\n\n        h_theta = 10000.0 * h_ntk_factor\n        w_theta = 10000.0 * w_ntk_factor\n        t_theta = 10000.0 * t_ntk_factor\n\n        h_spatial_freqs = 1.0 / (h_theta**self.dim_spatial_range.to(device=device))\n        w_spatial_freqs = 1.0 / (w_theta**self.dim_spatial_range.to(device=device))\n        temporal_freqs = 1.0 / (t_theta**self.dim_temporal_range.to(device=device))\n\n        B, T, H, W, _ = B_T_H_W_C\n        seq = torch.arange(max(H, W, T), dtype=torch.float, device=device)\n        uniform_fps = (fps is None) or isinstance(fps, (int, float)) or (fps.min() == fps.max())\n        assert (\n            uniform_fps or B == 
1 or T == 1\n        ), \"For video batch, batch size should be 1 for non-uniform fps. For image batch, T should be 1\"\n        half_emb_h = torch.outer(seq[:H].to(device=device), h_spatial_freqs)\n        half_emb_w = torch.outer(seq[:W].to(device=device), w_spatial_freqs)\n\n        # apply sequence scaling in temporal dimension\n        if fps is None or self.enable_fps_modulation is False:  # image case\n            half_emb_t = torch.outer(seq[:T].to(device=device), temporal_freqs)\n        else:\n            half_emb_t = torch.outer(seq[:T].to(device=device) / fps * self.base_fps, temporal_freqs)\n\n        half_emb_h = torch.stack([torch.cos(half_emb_h), -torch.sin(half_emb_h), torch.sin(half_emb_h), torch.cos(half_emb_h)], dim=-1)\n        half_emb_w = torch.stack([torch.cos(half_emb_w), -torch.sin(half_emb_w), torch.sin(half_emb_w), torch.cos(half_emb_w)], dim=-1)\n        half_emb_t = torch.stack([torch.cos(half_emb_t), -torch.sin(half_emb_t), torch.sin(half_emb_t), torch.cos(half_emb_t)], dim=-1)\n\n        em_T_H_W_D = torch.cat(\n            [\n                repeat(half_emb_t, \"t d x -> t h w d x\", h=H, w=W),\n                repeat(half_emb_h, \"h d x -> t h w d x\", t=T, w=W),\n                repeat(half_emb_w, \"w d x -> t h w d x\", t=T, h=H),\n            ]\n            , dim=-2,\n        )\n\n        return rearrange(em_T_H_W_D, \"t h w d (i j) -> (t h w) d i j\", i=2, j=2).float()\n\n\ndef apply_rotary_pos_emb(\n    t: torch.Tensor,\n    freqs: torch.Tensor,\n) -> torch.Tensor:\n    t_ = t.reshape(*t.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2).float()\n    t_out = freqs[..., 0] * t_[..., 0] + freqs[..., 1] * t_[..., 1]\n    t_out = t_out.movedim(-1, -2).reshape(*t.shape).type_as(t)\n    return t_out\n\n\n# ---------------------- Feed Forward Network -----------------------\nclass GPT2FeedForward(nn.Module):\n    def __init__(self, d_model: int, d_ff: int, device=None, dtype=None, operations=None) -> None:\n        super().__init__()\n 
       self.activation = nn.GELU()\n        self.layer1 = operations.Linear(d_model, d_ff, bias=False, device=device, dtype=dtype)\n        self.layer2 = operations.Linear(d_ff, d_model, bias=False, device=device, dtype=dtype)\n\n        self._layer_id = None\n        self._dim = d_model\n        self._hidden_dim = d_ff\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x = self.layer1(x)\n\n        x = self.activation(x)\n        x = self.layer2(x)\n        return x\n\n\ndef torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor, transformer_options: Optional[dict] = {}) -> torch.Tensor:\n    \"\"\"Computes multi-head attention using PyTorch's native implementation.\n\n    This function provides a PyTorch backend alternative to Transformer Engine's attention operation.\n    It rearranges the input tensors to match PyTorch's expected format, computes scaled dot-product\n    attention, and rearranges the output back to the original format.\n\n    The input tensor names use the following dimension conventions:\n\n    - B: batch size\n    - S: sequence length\n    - H: number of attention heads\n    - D: head dimension\n\n    Args:\n        q_B_S_H_D: Query tensor with shape (batch, seq_len, n_heads, head_dim)\n        k_B_S_H_D: Key tensor with shape (batch, seq_len, n_heads, head_dim)\n        v_B_S_H_D: Value tensor with shape (batch, seq_len, n_heads, head_dim)\n\n    Returns:\n        Attention output tensor with shape (batch, seq_len, n_heads * head_dim)\n    \"\"\"\n    in_q_shape = q_B_S_H_D.shape\n    in_k_shape = k_B_S_H_D.shape\n    q_B_H_S_D = rearrange(q_B_S_H_D, \"b ... h k -> b h ... k\").view(in_q_shape[0], in_q_shape[-2], -1, in_q_shape[-1])\n    k_B_H_S_D = rearrange(k_B_S_H_D, \"b ... h v -> b h ... v\").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])\n    v_B_H_S_D = rearrange(v_B_S_H_D, \"b ... h v -> b h ... 
v\").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])\n    return attention_forward(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D, out_pattern=\"b s (n d)\")\n\n\nclass Attention(nn.Module):\n    \"\"\"\n    A flexible attention module supporting both self-attention and cross-attention mechanisms.\n\n    This module implements a multi-head attention layer that can operate in either self-attention\n    or cross-attention mode. The mode is determined by whether a context dimension is provided.\n    The implementation uses scaled dot-product attention and supports optional bias terms and\n    dropout regularization.\n\n    Args:\n        query_dim (int): The dimensionality of the query vectors.\n        context_dim (int, optional): The dimensionality of the context (key/value) vectors.\n            If None, the module operates in self-attention mode using query_dim. Default: None\n        n_heads (int, optional): Number of attention heads for multi-head attention. Default: 8\n        head_dim (int, optional): The dimension of each attention head. Default: 64\n        dropout (float, optional): Dropout probability applied to the output. Default: 0.0\n        qkv_format (str, optional): Format specification for QKV tensors. Default: \"bshd\"\n        backend (str, optional): Backend to use for the attention operation. 
Default: \"transformer_engine\"\n\n    Examples:\n        >>> # Self-attention with 512 dimensions and 8 heads\n        >>> self_attn = Attention(query_dim=512)\n        >>> x = torch.randn(32, 16, 512)  # (batch_size, seq_len, dim)\n        >>> out = self_attn(x)  # (32, 16, 512)\n\n        >>> # Cross-attention\n        >>> cross_attn = Attention(query_dim=512, context_dim=256)\n        >>> query = torch.randn(32, 16, 512)\n        >>> context = torch.randn(32, 8, 256)\n        >>> out = cross_attn(query, context)  # (32, 16, 512)\n    \"\"\"\n\n    def __init__(\n        self,\n        query_dim: int,\n        context_dim: Optional[int] = None,\n        n_heads: int = 8,\n        head_dim: int = 64,\n        dropout: float = 0.0,\n        device=None,\n        dtype=None,\n        operations=None,\n    ) -> None:\n        super().__init__()\n        logging.debug(\n            f\"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using \"\n            f\"{n_heads} heads with a dimension of {head_dim}.\"\n        )\n        self.is_selfattn = context_dim is None  # self attention\n\n        context_dim = query_dim if context_dim is None else context_dim\n        inner_dim = head_dim * n_heads\n\n        self.n_heads = n_heads\n        self.head_dim = head_dim\n        self.query_dim = query_dim\n        self.context_dim = context_dim\n\n        self.q_proj = operations.Linear(query_dim, inner_dim, bias=False, device=device, dtype=dtype)\n        self.q_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype)\n\n        self.k_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype)\n        self.k_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype)\n\n        self.v_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype)\n        self.v_norm = nn.Identity()\n\n        self.output_proj = 
operations.Linear(inner_dim, query_dim, bias=False, device=device, dtype=dtype)\n        self.output_dropout = nn.Dropout(dropout) if dropout > 1e-4 else nn.Identity()\n\n        self.attn_op = torch_attention_op\n\n        self._query_dim = query_dim\n        self._context_dim = context_dim\n        self._inner_dim = inner_dim\n\n    def compute_qkv(\n        self,\n        x: torch.Tensor,\n        context: Optional[torch.Tensor] = None,\n        rope_emb: Optional[torch.Tensor] = None,\n    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n        q = self.q_proj(x)\n        context = x if context is None else context\n        k = self.k_proj(context)\n        v = self.v_proj(context)\n        q, k, v = map(\n            lambda t: rearrange(t, \"b ... (h d) -> b ... h d\", h=self.n_heads, d=self.head_dim),\n            (q, k, v),\n        )\n\n        def apply_norm_and_rotary_pos_emb(\n            q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, rope_emb: Optional[torch.Tensor]\n        ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n            q = self.q_norm(q)\n            k = self.k_norm(k)\n            v = self.v_norm(v)\n            if self.is_selfattn and rope_emb is not None:  # only apply to self-attention!\n                q = apply_rotary_pos_emb(q, rope_emb)\n                k = apply_rotary_pos_emb(k, rope_emb)\n            return q, k, v\n\n        q, k, v = apply_norm_and_rotary_pos_emb(q, k, v, rope_emb)\n\n        return q, k, v\n\n    def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, transformer_options: Optional[dict] = {}) -> torch.Tensor:\n        result = self.attn_op(q, k, v, transformer_options=transformer_options)  # [B, S, H, D]\n        return self.output_dropout(self.output_proj(result))\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        context: Optional[torch.Tensor] = None,\n        rope_emb: Optional[torch.Tensor] = None,\n        transformer_options: Optional[dict] = 
{},\n    ) -> torch.Tensor:\n        \"\"\"\n        Args:\n            x (Tensor): The query tensor of shape [B, Mq, K]\n            context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None\n        \"\"\"\n        q, k, v = self.compute_qkv(x, context, rope_emb=rope_emb)\n        return self.compute_attention(q, k, v, transformer_options=transformer_options)\n\n\nclass Timesteps(nn.Module):\n    def __init__(self, num_channels: int):\n        super().__init__()\n        self.num_channels = num_channels\n\n    def forward(self, timesteps_B_T: torch.Tensor) -> torch.Tensor:\n        assert timesteps_B_T.ndim == 2, f\"Expected 2D input, got {timesteps_B_T.ndim}\"\n        timesteps = timesteps_B_T.flatten().float()\n        half_dim = self.num_channels // 2\n        exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device)\n        exponent = exponent / (half_dim - 0.0)\n\n        emb = torch.exp(exponent)\n        emb = timesteps[:, None].float() * emb[None, :]\n\n        sin_emb = torch.sin(emb)\n        cos_emb = torch.cos(emb)\n        emb = torch.cat([cos_emb, sin_emb], dim=-1)\n\n        return rearrange(emb, \"(b t) d -> b t d\", b=timesteps_B_T.shape[0], t=timesteps_B_T.shape[1])\n\n\nclass TimestepEmbedding(nn.Module):\n    def __init__(self, in_features: int, out_features: int, use_adaln_lora: bool = False, device=None, dtype=None, operations=None):\n        super().__init__()\n        logging.debug(\n            f\"Using AdaLN LoRA Flag:  {use_adaln_lora}. 
We enable bias if no AdaLN LoRA for backward compatibility.\"\n        )\n        self.in_dim = in_features\n        self.out_dim = out_features\n        self.linear_1 = operations.Linear(in_features, out_features, bias=not use_adaln_lora, device=device, dtype=dtype)\n        self.activation = nn.SiLU()\n        self.use_adaln_lora = use_adaln_lora\n        if use_adaln_lora:\n            self.linear_2 = operations.Linear(out_features, 3 * out_features, bias=False, device=device, dtype=dtype)\n        else:\n            self.linear_2 = operations.Linear(out_features, out_features, bias=False, device=device, dtype=dtype)\n\n    def forward(self, sample: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:\n        emb = self.linear_1(sample)\n        emb = self.activation(emb)\n        emb = self.linear_2(emb)\n\n        if self.use_adaln_lora:\n            adaln_lora_B_T_3D = emb\n            emb_B_T_D = sample\n        else:\n            adaln_lora_B_T_3D = None\n            emb_B_T_D = emb\n\n        return emb_B_T_D, adaln_lora_B_T_3D\n\n\nclass PatchEmbed(nn.Module):\n    \"\"\"\n    PatchEmbed is a module for embedding patches from an input tensor by applying either 3D or 2D convolutional layers,\n    depending on the . This module can process inputs with temporal (video) and spatial (image) dimensions,\n    making it suitable for video and image processing tasks. It supports dividing the input into patches\n    and embedding each patch into a vector of size `out_channels`.\n\n    Parameters:\n    - spatial_patch_size (int): The size of each spatial patch.\n    - temporal_patch_size (int): The size of each temporal patch.\n    - in_channels (int): Number of input channels. Default: 3.\n    - out_channels (int): The dimension of the embedding vector for each patch. Default: 768.\n    - bias (bool): If True, adds a learnable bias to the output of the convolutional layers. 
Default: True.\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_patch_size: int,\n        temporal_patch_size: int,\n        in_channels: int = 3,\n        out_channels: int = 768,\n        device=None, dtype=None, operations=None\n    ):\n        super().__init__()\n        self.spatial_patch_size = spatial_patch_size\n        self.temporal_patch_size = temporal_patch_size\n\n        self.proj = nn.Sequential(\n            Rearrange(\n                \"b c (t r) (h m) (w n) -> b t h w (c r m n)\",\n                r=temporal_patch_size,\n                m=spatial_patch_size,\n                n=spatial_patch_size,\n            ),\n            operations.Linear(\n                in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size, out_channels, bias=False, device=device, dtype=dtype\n            ),\n        )\n        self.dim = in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Forward pass of the PatchEmbed module.\n\n        Parameters:\n        - x (torch.Tensor): The input tensor of shape (B, C, T, H, W) where\n            B is the batch size,\n            C is the number of channels,\n            T is the temporal dimension,\n            H is the height, and\n            W is the width of the input.\n\n        Returns:\n        - torch.Tensor: The embedded patches as a tensor, with shape b t h w c.\n        \"\"\"\n        assert x.dim() == 5\n        _, _, T, H, W = x.shape\n        assert (\n            H % self.spatial_patch_size == 0 and W % self.spatial_patch_size == 0\n        ), f\"H,W {(H, W)} should be divisible by spatial_patch_size {self.spatial_patch_size}\"\n        assert T % self.temporal_patch_size == 0\n        x = self.proj(x)\n        return x\n\n\nclass FinalLayer(nn.Module):\n    \"\"\"\n    The final layer of video DiT.\n    \"\"\"\n\n    def __init__(\n        self,\n        hidden_size: 
int,\n        spatial_patch_size: int,\n        temporal_patch_size: int,\n        out_channels: int,\n        use_adaln_lora: bool = False,\n        adaln_lora_dim: int = 256,\n        device=None, dtype=None, operations=None\n    ):\n        super().__init__()\n        self.layer_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)\n        self.linear = operations.Linear(\n            hidden_size, spatial_patch_size * spatial_patch_size * temporal_patch_size * out_channels, bias=False, device=device, dtype=dtype\n        )\n        self.hidden_size = hidden_size\n        self.n_adaln_chunks = 2\n        self.use_adaln_lora = use_adaln_lora\n        self.adaln_lora_dim = adaln_lora_dim\n        if use_adaln_lora:\n            self.adaln_modulation = nn.Sequential(\n                nn.SiLU(),\n                operations.Linear(hidden_size, adaln_lora_dim, bias=False, device=device, dtype=dtype),\n                operations.Linear(adaln_lora_dim, self.n_adaln_chunks * hidden_size, bias=False, device=device, dtype=dtype),\n            )\n        else:\n            self.adaln_modulation = nn.Sequential(\n                nn.SiLU(), operations.Linear(hidden_size, self.n_adaln_chunks * hidden_size, bias=False, device=device, dtype=dtype)\n            )\n\n    def forward(\n        self,\n        x_B_T_H_W_D: torch.Tensor,\n        emb_B_T_D: torch.Tensor,\n        adaln_lora_B_T_3D: Optional[torch.Tensor] = None,\n    ):\n        if self.use_adaln_lora:\n            assert adaln_lora_B_T_3D is not None\n            shift_B_T_D, scale_B_T_D = (\n                self.adaln_modulation(emb_B_T_D) + adaln_lora_B_T_3D[:, :, : 2 * self.hidden_size]\n            ).chunk(2, dim=-1)\n        else:\n            shift_B_T_D, scale_B_T_D = self.adaln_modulation(emb_B_T_D).chunk(2, dim=-1)\n\n        shift_B_T_1_1_D, scale_B_T_1_1_D = rearrange(shift_B_T_D, \"b t d -> b t 1 1 d\"), rearrange(\n            scale_B_T_D, \"b t d -> b t 1 1 d\"\n        )\n\n      
  def _fn(\n            _x_B_T_H_W_D: torch.Tensor,\n            _norm_layer: nn.Module,\n            _scale_B_T_1_1_D: torch.Tensor,\n            _shift_B_T_1_1_D: torch.Tensor,\n        ) -> torch.Tensor:\n            return _norm_layer(_x_B_T_H_W_D) * (1 + _scale_B_T_1_1_D) + _shift_B_T_1_1_D\n\n        x_B_T_H_W_D = _fn(x_B_T_H_W_D, self.layer_norm, scale_B_T_1_1_D, shift_B_T_1_1_D)\n        x_B_T_H_W_O = self.linear(x_B_T_H_W_D)\n        return x_B_T_H_W_O\n\n\nclass Block(nn.Module):\n    \"\"\"\n    A transformer block that combines self-attention, cross-attention and MLP layers with AdaLN modulation.\n    Each component (self-attention, cross-attention, MLP) has its own layer normalization and AdaLN modulation.\n\n    Parameters:\n        x_dim (int): Dimension of input features\n        context_dim (int): Dimension of context features for cross-attention\n        num_heads (int): Number of attention heads\n        mlp_ratio (float): Multiplier for MLP hidden dimension. Default: 4.0\n        use_adaln_lora (bool): Whether to use AdaLN-LoRA modulation. Default: False\n        adaln_lora_dim (int): Hidden dimension for AdaLN-LoRA layers. Default: 256\n\n    The block applies the following sequence:\n    1. Self-attention with AdaLN modulation\n    2. Cross-attention with AdaLN modulation\n    3. 
MLP with AdaLN modulation\n\n    Each component uses skip connections and layer normalization.\n    \"\"\"\n\n    def __init__(\n        self,\n        x_dim: int,\n        context_dim: int,\n        num_heads: int,\n        mlp_ratio: float = 4.0,\n        use_adaln_lora: bool = False,\n        adaln_lora_dim: int = 256,\n        device=None,\n        dtype=None,\n        operations=None,\n    ):\n        super().__init__()\n        self.x_dim = x_dim\n        self.layer_norm_self_attn = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)\n        self.self_attn = Attention(x_dim, None, num_heads, x_dim // num_heads, device=device, dtype=dtype, operations=operations)\n\n        self.layer_norm_cross_attn = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)\n        self.cross_attn = Attention(\n            x_dim, context_dim, num_heads, x_dim // num_heads, device=device, dtype=dtype, operations=operations\n        )\n\n        self.layer_norm_mlp = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)\n        self.mlp = GPT2FeedForward(x_dim, int(x_dim * mlp_ratio), device=device, dtype=dtype, operations=operations)\n\n        self.use_adaln_lora = use_adaln_lora\n        if self.use_adaln_lora:\n            self.adaln_modulation_self_attn = nn.Sequential(\n                nn.SiLU(),\n                operations.Linear(x_dim, adaln_lora_dim, bias=False, device=device, dtype=dtype),\n                operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, device=device, dtype=dtype),\n            )\n            self.adaln_modulation_cross_attn = nn.Sequential(\n                nn.SiLU(),\n                operations.Linear(x_dim, adaln_lora_dim, bias=False, device=device, dtype=dtype),\n                operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, device=device, dtype=dtype),\n            )\n            self.adaln_modulation_mlp = 
nn.Sequential(\n                nn.SiLU(),\n                operations.Linear(x_dim, adaln_lora_dim, bias=False, device=device, dtype=dtype),\n                operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, device=device, dtype=dtype),\n            )\n        else:\n            self.adaln_modulation_self_attn = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, 3 * x_dim, bias=False, device=device, dtype=dtype))\n            self.adaln_modulation_cross_attn = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, 3 * x_dim, bias=False, device=device, dtype=dtype))\n            self.adaln_modulation_mlp = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, 3 * x_dim, bias=False, device=device, dtype=dtype))\n\n    def forward(\n        self,\n        x_B_T_H_W_D: torch.Tensor,\n        emb_B_T_D: torch.Tensor,\n        crossattn_emb: torch.Tensor,\n        rope_emb_L_1_1_D: Optional[torch.Tensor] = None,\n        adaln_lora_B_T_3D: Optional[torch.Tensor] = None,\n        extra_per_block_pos_emb: Optional[torch.Tensor] = None,\n        transformer_options: Optional[dict] = {},\n    ) -> torch.Tensor:\n        residual_dtype = x_B_T_H_W_D.dtype\n        compute_dtype = emb_B_T_D.dtype\n        if extra_per_block_pos_emb is not None:\n            x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb\n\n        if self.use_adaln_lora:\n            shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = (\n                self.adaln_modulation_self_attn(emb_B_T_D) + adaln_lora_B_T_3D\n            ).chunk(3, dim=-1)\n            shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = (\n                self.adaln_modulation_cross_attn(emb_B_T_D) + adaln_lora_B_T_3D\n            ).chunk(3, dim=-1)\n            shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = (\n                self.adaln_modulation_mlp(emb_B_T_D) + adaln_lora_B_T_3D\n            ).chunk(3, dim=-1)\n        else:\n            shift_self_attn_B_T_D, scale_self_attn_B_T_D, 
gate_self_attn_B_T_D = self.adaln_modulation_self_attn(\n                emb_B_T_D\n            ).chunk(3, dim=-1)\n            shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = self.adaln_modulation_cross_attn(\n                emb_B_T_D\n            ).chunk(3, dim=-1)\n            shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = self.adaln_modulation_mlp(emb_B_T_D).chunk(3, dim=-1)\n\n        # Reshape tensors from (B, T, D) to (B, T, 1, 1, D) for broadcasting\n        shift_self_attn_B_T_1_1_D = rearrange(shift_self_attn_B_T_D, \"b t d -> b t 1 1 d\")\n        scale_self_attn_B_T_1_1_D = rearrange(scale_self_attn_B_T_D, \"b t d -> b t 1 1 d\")\n        gate_self_attn_B_T_1_1_D = rearrange(gate_self_attn_B_T_D, \"b t d -> b t 1 1 d\")\n\n        shift_cross_attn_B_T_1_1_D = rearrange(shift_cross_attn_B_T_D, \"b t d -> b t 1 1 d\")\n        scale_cross_attn_B_T_1_1_D = rearrange(scale_cross_attn_B_T_D, \"b t d -> b t 1 1 d\")\n        gate_cross_attn_B_T_1_1_D = rearrange(gate_cross_attn_B_T_D, \"b t d -> b t 1 1 d\")\n\n        shift_mlp_B_T_1_1_D = rearrange(shift_mlp_B_T_D, \"b t d -> b t 1 1 d\")\n        scale_mlp_B_T_1_1_D = rearrange(scale_mlp_B_T_D, \"b t d -> b t 1 1 d\")\n        gate_mlp_B_T_1_1_D = rearrange(gate_mlp_B_T_D, \"b t d -> b t 1 1 d\")\n\n        B, T, H, W, D = x_B_T_H_W_D.shape\n\n        def _fn(_x_B_T_H_W_D, _norm_layer, _scale_B_T_1_1_D, _shift_B_T_1_1_D):\n            return _norm_layer(_x_B_T_H_W_D) * (1 + _scale_B_T_1_1_D) + _shift_B_T_1_1_D\n\n        normalized_x_B_T_H_W_D = _fn(\n            x_B_T_H_W_D,\n            self.layer_norm_self_attn,\n            scale_self_attn_B_T_1_1_D,\n            shift_self_attn_B_T_1_1_D,\n        )\n        result_B_T_H_W_D = rearrange(\n            self.self_attn(\n                # normalized_x_B_T_HW_D,\n                rearrange(normalized_x_B_T_H_W_D.to(compute_dtype), \"b t h w d -> b (t h w) d\"),\n                None,\n                rope_emb=rope_emb_L_1_1_D,\n 
               transformer_options=transformer_options,\n            ),\n            \"b (t h w) d -> b t h w d\",\n            t=T,\n            h=H,\n            w=W,\n        )\n        x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype)\n\n        def _x_fn(\n            _x_B_T_H_W_D: torch.Tensor,\n            layer_norm_cross_attn: Callable,\n            _scale_cross_attn_B_T_1_1_D: torch.Tensor,\n            _shift_cross_attn_B_T_1_1_D: torch.Tensor,\n            transformer_options: Optional[dict] = {},\n        ) -> torch.Tensor:\n            _normalized_x_B_T_H_W_D = _fn(\n                _x_B_T_H_W_D, layer_norm_cross_attn, _scale_cross_attn_B_T_1_1_D, _shift_cross_attn_B_T_1_1_D\n            )\n            _result_B_T_H_W_D = rearrange(\n                self.cross_attn(\n                    rearrange(_normalized_x_B_T_H_W_D.to(compute_dtype), \"b t h w d -> b (t h w) d\"),\n                    crossattn_emb,\n                    rope_emb=rope_emb_L_1_1_D,\n                    transformer_options=transformer_options,\n                ),\n                \"b (t h w) d -> b t h w d\",\n                t=T,\n                h=H,\n                w=W,\n            )\n            return _result_B_T_H_W_D\n\n        result_B_T_H_W_D = _x_fn(\n            x_B_T_H_W_D,\n            self.layer_norm_cross_attn,\n            scale_cross_attn_B_T_1_1_D,\n            shift_cross_attn_B_T_1_1_D,\n            transformer_options=transformer_options,\n        )\n        x_B_T_H_W_D = result_B_T_H_W_D.to(residual_dtype) * gate_cross_attn_B_T_1_1_D.to(residual_dtype) + x_B_T_H_W_D\n\n        normalized_x_B_T_H_W_D = _fn(\n            x_B_T_H_W_D,\n            self.layer_norm_mlp,\n            scale_mlp_B_T_1_1_D,\n            shift_mlp_B_T_1_1_D,\n        )\n        result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D.to(compute_dtype))\n        x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D.to(residual_dtype) * 
result_B_T_H_W_D.to(residual_dtype)\n        return x_B_T_H_W_D\n\n\nclass MiniTrainDIT(nn.Module):\n    \"\"\"\n    A clean impl of DIT that can load and  reproduce the training results of the original DIT model in~(cosmos 1)\n    A general implementation of adaln-modulated VIT-like~(DiT) transformer for video processing.\n\n    Args:\n        max_img_h (int): Maximum height of the input images.\n        max_img_w (int): Maximum width of the input images.\n        max_frames (int): Maximum number of frames in the video sequence.\n        in_channels (int): Number of input channels (e.g., RGB channels for color images).\n        out_channels (int): Number of output channels.\n        patch_spatial (tuple): Spatial resolution of patches for input processing.\n        patch_temporal (int): Temporal resolution of patches for input processing.\n        concat_padding_mask (bool): If True, includes a mask channel in the input to handle padding.\n        model_channels (int): Base number of channels used throughout the model.\n        num_blocks (int): Number of transformer blocks.\n        num_heads (int): Number of heads in the multi-head attention layers.\n        mlp_ratio (float): Expansion ratio for MLP blocks.\n        crossattn_emb_channels (int): Number of embedding channels for cross-attention.\n        pos_emb_cls (str): Type of positional embeddings.\n        pos_emb_learnable (bool): Whether positional embeddings are learnable.\n        pos_emb_interpolation (str): Method for interpolating positional embeddings.\n        min_fps (int): Minimum frames per second.\n        max_fps (int): Maximum frames per second.\n        use_adaln_lora (bool): Whether to use AdaLN-LoRA.\n        adaln_lora_dim (int): Dimension for AdaLN-LoRA.\n        rope_h_extrapolation_ratio (float): Height extrapolation ratio for RoPE.\n        rope_w_extrapolation_ratio (float): Width extrapolation ratio for RoPE.\n        rope_t_extrapolation_ratio (float): Temporal extrapolation ratio 
for RoPE.\n        extra_per_block_abs_pos_emb (bool): Whether to use extra per-block absolute positional embeddings.\n        extra_h_extrapolation_ratio (float): Height extrapolation ratio for extra embeddings.\n        extra_w_extrapolation_ratio (float): Width extrapolation ratio for extra embeddings.\n        extra_t_extrapolation_ratio (float): Temporal extrapolation ratio for extra embeddings.\n    \"\"\"\n\n    def __init__(\n        self,\n        max_img_h: int,\n        max_img_w: int,\n        max_frames: int,\n        in_channels: int,\n        out_channels: int,\n        patch_spatial: int,  # tuple,\n        patch_temporal: int,\n        concat_padding_mask: bool = True,\n        # attention settings\n        model_channels: int = 768,\n        num_blocks: int = 10,\n        num_heads: int = 16,\n        mlp_ratio: float = 4.0,\n        # cross attention settings\n        crossattn_emb_channels: int = 1024,\n        # positional embedding settings\n        pos_emb_cls: str = \"sincos\",\n        pos_emb_learnable: bool = False,\n        pos_emb_interpolation: str = \"crop\",\n        min_fps: int = 1,\n        max_fps: int = 30,\n        use_adaln_lora: bool = False,\n        adaln_lora_dim: int = 256,\n        rope_h_extrapolation_ratio: float = 1.0,\n        rope_w_extrapolation_ratio: float = 1.0,\n        rope_t_extrapolation_ratio: float = 1.0,\n        extra_per_block_abs_pos_emb: bool = False,\n        extra_h_extrapolation_ratio: float = 1.0,\n        extra_w_extrapolation_ratio: float = 1.0,\n        extra_t_extrapolation_ratio: float = 1.0,\n        rope_enable_fps_modulation: bool = True,\n        image_model=None,\n        device=None,\n        dtype=None,\n        operations=None,\n    ) -> None:\n        super().__init__()\n        self.dtype = dtype\n        self.max_img_h = max_img_h\n        self.max_img_w = max_img_w\n        self.max_frames = max_frames\n        self.in_channels = in_channels\n        self.out_channels = 
out_channels\n        self.patch_spatial = patch_spatial\n        self.patch_temporal = patch_temporal\n        self.num_heads = num_heads\n        self.num_blocks = num_blocks\n        self.model_channels = model_channels\n        self.concat_padding_mask = concat_padding_mask\n        # positional embedding settings\n        self.pos_emb_cls = pos_emb_cls\n        self.pos_emb_learnable = pos_emb_learnable\n        self.pos_emb_interpolation = pos_emb_interpolation\n        self.min_fps = min_fps\n        self.max_fps = max_fps\n        self.rope_h_extrapolation_ratio = rope_h_extrapolation_ratio\n        self.rope_w_extrapolation_ratio = rope_w_extrapolation_ratio\n        self.rope_t_extrapolation_ratio = rope_t_extrapolation_ratio\n        self.extra_per_block_abs_pos_emb = extra_per_block_abs_pos_emb\n        self.extra_h_extrapolation_ratio = extra_h_extrapolation_ratio\n        self.extra_w_extrapolation_ratio = extra_w_extrapolation_ratio\n        self.extra_t_extrapolation_ratio = extra_t_extrapolation_ratio\n        self.rope_enable_fps_modulation = rope_enable_fps_modulation\n\n        self.build_pos_embed(device=device, dtype=dtype)\n        self.use_adaln_lora = use_adaln_lora\n        self.adaln_lora_dim = adaln_lora_dim\n        self.t_embedder = nn.Sequential(\n            Timesteps(model_channels),\n            TimestepEmbedding(model_channels, model_channels, use_adaln_lora=use_adaln_lora, device=device, dtype=dtype, operations=operations,),\n        )\n\n        in_channels = in_channels + 1 if concat_padding_mask else in_channels\n        self.x_embedder = PatchEmbed(\n            spatial_patch_size=patch_spatial,\n            temporal_patch_size=patch_temporal,\n            in_channels=in_channels,\n            out_channels=model_channels,\n            device=device, dtype=dtype, operations=operations,\n        )\n\n        self.blocks = nn.ModuleList(\n            [\n                Block(\n                    x_dim=model_channels,\n          
          context_dim=crossattn_emb_channels,\n                    num_heads=num_heads,\n                    mlp_ratio=mlp_ratio,\n                    use_adaln_lora=use_adaln_lora,\n                    adaln_lora_dim=adaln_lora_dim,\n                    device=device, dtype=dtype, operations=operations,\n                )\n                for _ in range(num_blocks)\n            ]\n        )\n\n        self.final_layer = FinalLayer(\n            hidden_size=self.model_channels,\n            spatial_patch_size=self.patch_spatial,\n            temporal_patch_size=self.patch_temporal,\n            out_channels=self.out_channels,\n            use_adaln_lora=self.use_adaln_lora,\n            adaln_lora_dim=self.adaln_lora_dim,\n            device=device, dtype=dtype, operations=operations,\n        )\n\n        self.t_embedding_norm = operations.RMSNorm(model_channels, eps=1e-6, device=device, dtype=dtype)\n\n    def build_pos_embed(self, device=None, dtype=None) -> None:\n        if self.pos_emb_cls == \"rope3d\":\n            cls_type = VideoRopePosition3DEmb\n        else:\n            raise ValueError(f\"Unknown pos_emb_cls {self.pos_emb_cls}\")\n\n        logging.debug(f\"Building positional embedding with {self.pos_emb_cls} class, impl {cls_type}\")\n        kwargs = dict(\n            model_channels=self.model_channels,\n            len_h=self.max_img_h // self.patch_spatial,\n            len_w=self.max_img_w // self.patch_spatial,\n            len_t=self.max_frames // self.patch_temporal,\n            max_fps=self.max_fps,\n            min_fps=self.min_fps,\n            is_learnable=self.pos_emb_learnable,\n            interpolation=self.pos_emb_interpolation,\n            head_dim=self.model_channels // self.num_heads,\n            h_extrapolation_ratio=self.rope_h_extrapolation_ratio,\n            w_extrapolation_ratio=self.rope_w_extrapolation_ratio,\n            t_extrapolation_ratio=self.rope_t_extrapolation_ratio,\n            
enable_fps_modulation=self.rope_enable_fps_modulation,\n            device=device,\n        )\n        self.pos_embedder = cls_type(\n            **kwargs,  # type: ignore\n        )\n\n        if self.extra_per_block_abs_pos_emb:\n            kwargs[\"h_extrapolation_ratio\"] = self.extra_h_extrapolation_ratio\n            kwargs[\"w_extrapolation_ratio\"] = self.extra_w_extrapolation_ratio\n            kwargs[\"t_extrapolation_ratio\"] = self.extra_t_extrapolation_ratio\n            kwargs[\"device\"] = device\n            kwargs[\"dtype\"] = dtype\n            self.extra_pos_embedder = LearnablePosEmbAxis(\n                **kwargs,  # type: ignore\n            )\n\n    def prepare_embedded_sequence(\n        self,\n        x_B_C_T_H_W: torch.Tensor,\n        fps: Optional[torch.Tensor] = None,\n        padding_mask: Optional[torch.Tensor] = None,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:\n        \"\"\"\n        Prepares an embedded sequence tensor by applying positional embeddings and handling padding masks.\n\n        Args:\n            x_B_C_T_H_W (torch.Tensor): video\n            fps (Optional[torch.Tensor]): Frames per second tensor to be used for positional embedding when required.\n                                    If None, a default value (`self.base_fps`) will be used.\n            padding_mask (Optional[torch.Tensor]): current it is not used\n\n        Returns:\n            Tuple[torch.Tensor, Optional[torch.Tensor]]:\n                - A tensor of shape (B, T, H, W, D) with the embedded sequence.\n                - An optional positional embedding tensor, returned only if the positional embedding class\n                (`self.pos_emb_cls`) includes 'rope'. 
Otherwise, None.\n\n        Notes:\n            - If `self.concat_padding_mask` is True, a padding mask channel is concatenated to the input tensor.\n            - The method of applying positional embeddings depends on the value of `self.pos_emb_cls`.\n            - If 'rope' is in `self.pos_emb_cls` (case insensitive), the positional embeddings are generated using\n                the `self.pos_embedder` with the shape [T, H, W].\n            - If \"fps_aware\" is in `self.pos_emb_cls`, the positional embeddings are generated using the\n            `self.pos_embedder` with the fps tensor.\n            - Otherwise, the positional embeddings are generated without considering fps.\n        \"\"\"\n        if self.concat_padding_mask:\n            if padding_mask is None:\n                padding_mask = torch.zeros(x_B_C_T_H_W.shape[0], 1, x_B_C_T_H_W.shape[3], x_B_C_T_H_W.shape[4], dtype=x_B_C_T_H_W.dtype, device=x_B_C_T_H_W.device)\n            else:\n                padding_mask = transforms.functional.resize(\n                    padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST\n                )\n            x_B_C_T_H_W = torch.cat(\n                [x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1\n            )\n        x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W)\n\n        if self.extra_per_block_abs_pos_emb:\n            extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device, dtype=x_B_C_T_H_W.dtype)\n        else:\n            extra_pos_emb = None\n\n        if \"rope\" in self.pos_emb_cls.lower():\n            return x_B_T_H_W_D, self.pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device), extra_pos_emb\n        x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, device=x_B_C_T_H_W.device)  # [B, T, H, W, D]\n\n        return x_B_T_H_W_D, None, extra_pos_emb\n\n    def unpatchify(self, x_B_T_H_W_M: torch.Tensor) -> 
torch.Tensor:\n        x_B_C_Tt_Hp_Wp = rearrange(\n            x_B_T_H_W_M,\n            \"B T H W (p1 p2 t C) -> B C (T t) (H p1) (W p2)\",\n            p1=self.patch_spatial,\n            p2=self.patch_spatial,\n            t=self.patch_temporal,\n        )\n        return x_B_C_Tt_Hp_Wp\n    \n    def pad_to_patch_size(self, img, patch_size=(2, 2), padding_mode=\"circular\"):\n        if padding_mode == \"circular\" and (torch.jit.is_tracing() or torch.jit.is_scripting()):\n            padding_mode = \"reflect\"\n\n        pad = ()\n        for i in range(img.ndim - 2):\n            pad = (0, (patch_size[i] - img.shape[i + 2] % patch_size[i]) % patch_size[i]) + pad\n\n        return torch.nn.functional.pad(img, pad, mode=padding_mode)\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        timesteps: torch.Tensor,\n        context: torch.Tensor,\n        fps: Optional[torch.Tensor] = None,\n        padding_mask: Optional[torch.Tensor] = None,\n        use_gradient_checkpointing=False,\n        use_gradient_checkpointing_offload=False,\n        **kwargs,\n    ):\n        orig_shape = list(x.shape)\n        x = self.pad_to_patch_size(x, (self.patch_temporal, self.patch_spatial, self.patch_spatial))\n        x_B_C_T_H_W = x\n        timesteps_B_T = timesteps\n        crossattn_emb = context\n        \"\"\"\n        Args:\n            x: (B, C, T, H, W) tensor of spatial-temp inputs\n            timesteps: (B, ) tensor of timesteps\n            crossattn_emb: (B, N, D) tensor of cross-attention embeddings\n        \"\"\"\n        x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = self.prepare_embedded_sequence(\n            x_B_C_T_H_W,\n            fps=fps,\n            padding_mask=padding_mask,\n        )\n\n        if timesteps_B_T.ndim == 1:\n            timesteps_B_T = timesteps_B_T.unsqueeze(1)\n        t_embedding_B_T_D, adaln_lora_B_T_3D = self.t_embedder[1](self.t_embedder[0](timesteps_B_T).to(x_B_T_H_W_D.dtype))\n        
t_embedding_B_T_D = self.t_embedding_norm(t_embedding_B_T_D)\n\n        # for logging purpose\n        affline_scale_log_info = {}\n        affline_scale_log_info[\"t_embedding_B_T_D\"] = t_embedding_B_T_D.detach()\n        self.affline_scale_log_info = affline_scale_log_info\n        self.affline_emb = t_embedding_B_T_D\n        self.crossattn_emb = crossattn_emb\n\n        if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:\n            assert (\n                x_B_T_H_W_D.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape\n            ), f\"{x_B_T_H_W_D.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape}\"\n\n        block_kwargs = {\n            \"rope_emb_L_1_1_D\": rope_emb_L_1_1_D.unsqueeze(1).unsqueeze(0),\n            \"adaln_lora_B_T_3D\": adaln_lora_B_T_3D,\n            \"extra_per_block_pos_emb\": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,\n            \"transformer_options\": kwargs.get(\"transformer_options\", {}),\n        }\n\n        # The residual stream for this model has large values. 
To make fp16 compute_dtype work, we keep the residual stream\n        # in fp32, but run attention and MLP modules in fp16.\n        # An alternate method that clamps fp16 values \"works\" in the sense that it makes coherent images, but there is noticeable\n        # quality degradation and visual artifacts.\n        if x_B_T_H_W_D.dtype == torch.float16:\n            x_B_T_H_W_D = x_B_T_H_W_D.float()\n\n        for block in self.blocks:\n            x_B_T_H_W_D = gradient_checkpoint_forward(\n                block,\n                use_gradient_checkpointing=use_gradient_checkpointing,\n                use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,\n                x_B_T_H_W_D=x_B_T_H_W_D,\n                emb_B_T_D=t_embedding_B_T_D,\n                crossattn_emb=crossattn_emb,\n                **block_kwargs,\n            )\n\n        x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D.to(crossattn_emb.dtype), t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D)\n        x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O)[:, :, :orig_shape[-3], :orig_shape[-2], :orig_shape[-1]]\n        return x_B_C_Tt_Hp_Wp\n\n\ndef rotate_half(x):\n    x1 = x[..., : x.shape[-1] // 2]\n    x2 = x[..., x.shape[-1] // 2 :]\n    return torch.cat((-x2, x1), dim=-1)\n\n\ndef apply_rotary_pos_emb2(x, cos, sin, unsqueeze_dim=1):\n    cos = cos.unsqueeze(unsqueeze_dim)\n    sin = sin.unsqueeze(unsqueeze_dim)\n    x_embed = (x * cos) + (rotate_half(x) * sin)\n    return x_embed\n\n\nclass RotaryEmbedding(nn.Module):\n    def __init__(self, head_dim):\n        super().__init__()\n        self.rope_theta = 10000\n        inv_freq = 1.0 / (self.rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.int64).to(dtype=torch.float) / head_dim))\n        self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n\n    @torch.no_grad()\n    def forward(self, x, position_ids):\n        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 
1).to(x.device)\n        position_ids_expanded = position_ids[:, None, :].float()\n\n        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != \"mps\" else \"cpu\"\n        with torch.autocast(device_type=device_type, enabled=False):  # Force float32\n            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)\n            emb = torch.cat((freqs, freqs), dim=-1)\n            cos = emb.cos()\n            sin = emb.sin()\n\n        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)\n\n\nclass LLMAdapterAttention(nn.Module):\n    def __init__(self, query_dim, context_dim, n_heads, head_dim, device=None, dtype=None, operations=None):\n        super().__init__()\n\n        inner_dim = head_dim * n_heads\n        self.n_heads = n_heads\n        self.head_dim = head_dim\n        self.query_dim = query_dim\n        self.context_dim = context_dim\n\n        self.q_proj = operations.Linear(query_dim, inner_dim, bias=False, device=device, dtype=dtype)\n        self.q_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype)\n\n        self.k_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype)\n        self.k_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype)\n\n        self.v_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype)\n\n        self.o_proj = operations.Linear(inner_dim, query_dim, bias=False, device=device, dtype=dtype)\n\n    def forward(self, x, mask=None, context=None, position_embeddings=None, position_embeddings_context=None):\n        context = x if context is None else context\n        input_shape = x.shape[:-1]\n        q_shape = (*input_shape, self.n_heads, self.head_dim)\n        context_shape = context.shape[:-1]\n        kv_shape = (*context_shape, self.n_heads, self.head_dim)\n\n        query_states = self.q_norm(self.q_proj(x).view(q_shape)).transpose(1, 2)\n     
   key_states = self.k_norm(self.k_proj(context).view(kv_shape)).transpose(1, 2)\n        value_states = self.v_proj(context).view(kv_shape).transpose(1, 2)\n\n        if position_embeddings is not None:\n            assert position_embeddings_context is not None\n            cos, sin = position_embeddings\n            query_states = apply_rotary_pos_emb2(query_states, cos, sin)\n            cos, sin = position_embeddings_context\n            key_states = apply_rotary_pos_emb2(key_states, cos, sin)\n\n        attn_output = torch.nn.functional.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=mask)\n\n        attn_output = attn_output.transpose(1, 2).reshape(*input_shape, -1).contiguous()\n        attn_output = self.o_proj(attn_output)\n        return attn_output\n\n    def init_weights(self):\n        torch.nn.init.zeros_(self.o_proj.weight)\n\n\nclass LLMAdapterTransformerBlock(nn.Module):\n    def __init__(self, source_dim, model_dim, num_heads=16, mlp_ratio=4.0, use_self_attn=False, layer_norm=False, device=None, dtype=None, operations=None):\n        super().__init__()\n        self.use_self_attn = use_self_attn\n\n        if self.use_self_attn:\n            self.norm_self_attn = operations.LayerNorm(model_dim, device=device, dtype=dtype) if layer_norm else operations.RMSNorm(model_dim, eps=1e-6, device=device, dtype=dtype)\n            self.self_attn = LLMAdapterAttention(\n                query_dim=model_dim,\n                context_dim=model_dim,\n                n_heads=num_heads,\n                head_dim=model_dim//num_heads,\n                device=device,\n                dtype=dtype,\n                operations=operations,\n            )\n\n        self.norm_cross_attn = operations.LayerNorm(model_dim, device=device, dtype=dtype) if layer_norm else operations.RMSNorm(model_dim, eps=1e-6, device=device, dtype=dtype)\n        self.cross_attn = LLMAdapterAttention(\n            query_dim=model_dim,\n            
context_dim=source_dim,\n            n_heads=num_heads,\n            head_dim=model_dim//num_heads,\n            device=device,\n            dtype=dtype,\n            operations=operations,\n        )\n\n        self.norm_mlp = operations.LayerNorm(model_dim, device=device, dtype=dtype) if layer_norm else operations.RMSNorm(model_dim, eps=1e-6, device=device, dtype=dtype)\n        self.mlp = nn.Sequential(\n            operations.Linear(model_dim, int(model_dim * mlp_ratio), device=device, dtype=dtype),\n            nn.GELU(),\n            operations.Linear(int(model_dim * mlp_ratio), model_dim, device=device, dtype=dtype)\n        )\n\n    def forward(self, x, context, target_attention_mask=None, source_attention_mask=None, position_embeddings=None, position_embeddings_context=None):\n        if self.use_self_attn:\n            normed = self.norm_self_attn(x)\n            attn_out = self.self_attn(normed, mask=target_attention_mask, position_embeddings=position_embeddings, position_embeddings_context=position_embeddings)\n            x = x + attn_out\n\n        normed = self.norm_cross_attn(x)\n        attn_out = self.cross_attn(normed, mask=source_attention_mask, context=context, position_embeddings=position_embeddings, position_embeddings_context=position_embeddings_context)\n        x = x + attn_out\n\n        x = x + self.mlp(self.norm_mlp(x))\n        return x\n\n    def init_weights(self):\n        torch.nn.init.zeros_(self.mlp[2].weight)\n        self.cross_attn.init_weights()\n\n\nclass LLMAdapter(nn.Module):\n    def __init__(\n            self,\n            source_dim=1024,\n            target_dim=1024,\n            model_dim=1024,\n            num_layers=6,\n            num_heads=16,\n            use_self_attn=True,\n            layer_norm=False,\n            device=None,\n            dtype=None,\n            operations=None,\n        ):\n        super().__init__()\n\n        self.embed = operations.Embedding(32128, target_dim, device=device, 
dtype=dtype)\n        if model_dim != target_dim:\n            self.in_proj = operations.Linear(target_dim, model_dim, device=device, dtype=dtype)\n        else:\n            self.in_proj = nn.Identity()\n        self.rotary_emb = RotaryEmbedding(model_dim//num_heads)\n        self.blocks = nn.ModuleList([\n            LLMAdapterTransformerBlock(source_dim, model_dim, num_heads=num_heads, use_self_attn=use_self_attn, layer_norm=layer_norm, device=device, dtype=dtype, operations=operations) for _ in range(num_layers)\n        ])\n        self.out_proj = operations.Linear(model_dim, target_dim, device=device, dtype=dtype)\n        self.norm = operations.RMSNorm(target_dim, eps=1e-6, device=device, dtype=dtype)\n\n    def forward(self, source_hidden_states, target_input_ids, target_attention_mask=None, source_attention_mask=None):\n        if target_attention_mask is not None:\n            target_attention_mask = target_attention_mask.to(torch.bool)\n            if target_attention_mask.ndim == 2:\n                target_attention_mask = target_attention_mask.unsqueeze(1).unsqueeze(1)\n\n        if source_attention_mask is not None:\n            source_attention_mask = source_attention_mask.to(torch.bool)\n            if source_attention_mask.ndim == 2:\n                source_attention_mask = source_attention_mask.unsqueeze(1).unsqueeze(1)\n\n        context = source_hidden_states\n        x = self.in_proj(self.embed(target_input_ids).to(context.dtype))\n        position_ids = torch.arange(x.shape[1], device=x.device).unsqueeze(0)\n        position_ids_context = torch.arange(context.shape[1], device=x.device).unsqueeze(0)\n        position_embeddings = self.rotary_emb(x, position_ids)\n        position_embeddings_context = self.rotary_emb(x, position_ids_context)\n        for block in self.blocks:\n            x = block(x, context, target_attention_mask=target_attention_mask, source_attention_mask=source_attention_mask, position_embeddings=position_embeddings, 
position_embeddings_context=position_embeddings_context)\n        return self.norm(self.out_proj(x))\n\n\nclass AnimaDiT(MiniTrainDIT):\n    def __init__(self):\n        kwargs = {'image_model': 'anima', 'max_img_h': 240, 'max_img_w': 240, 'max_frames': 128, 'in_channels': 16, 'out_channels': 16, 'patch_spatial': 2, 'patch_temporal': 1, 'model_channels': 2048, 'concat_padding_mask': True, 'crossattn_emb_channels': 1024, 'pos_emb_cls': 'rope3d', 'pos_emb_learnable': True, 'pos_emb_interpolation': 'crop', 'min_fps': 1, 'max_fps': 30, 'use_adaln_lora': True, 'adaln_lora_dim': 256, 'num_blocks': 28, 'num_heads': 16, 'extra_per_block_abs_pos_emb': False, 'rope_h_extrapolation_ratio': 4.0, 'rope_w_extrapolation_ratio': 4.0, 'rope_t_extrapolation_ratio': 1.0, 'extra_h_extrapolation_ratio': 1.0, 'extra_w_extrapolation_ratio': 1.0, 'extra_t_extrapolation_ratio': 1.0, 'rope_enable_fps_modulation': False, 'dtype': torch.bfloat16, 'device': None, 'operations': torch.nn}\n        super().__init__(**kwargs)\n        self.llm_adapter = LLMAdapter(device=kwargs.get(\"device\"), dtype=kwargs.get(\"dtype\"), operations=kwargs.get(\"operations\"))\n\n    def preprocess_text_embeds(self, text_embeds, text_ids, t5xxl_weights=None):\n        if text_ids is not None:\n            out = self.llm_adapter(text_embeds, text_ids)\n            if t5xxl_weights is not None:\n                out = out * t5xxl_weights\n\n            if out.shape[1] < 512:\n                out = torch.nn.functional.pad(out, (0, 0, 0, 512 - out.shape[1]))\n            return out\n        else:\n            return text_embeds\n\n    def forward(\n        self,\n        x, timesteps, context,\n        use_gradient_checkpointing=False,\n        use_gradient_checkpointing_offload=False,\n        **kwargs\n    ):\n        t5xxl_ids = kwargs.pop(\"t5xxl_ids\", None)\n        if t5xxl_ids is not None:\n            context = self.preprocess_text_embeds(context, t5xxl_ids, t5xxl_weights=kwargs.pop(\"t5xxl_weights\", 
None))\n        return super().forward(\n            x, timesteps, context,\n            use_gradient_checkpointing=use_gradient_checkpointing, use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,\n            **kwargs\n        )\n"
  },
  {
    "path": "diffsynth/models/dinov3_image_encoder.py",
    "content": "from transformers import DINOv3ViTModel, DINOv3ViTImageProcessorFast\nfrom transformers.models.dinov3_vit.modeling_dinov3_vit import DINOv3ViTConfig\nimport torch\n\nfrom ..core.device.npu_compatible_device import get_device_type\n\n\nclass DINOv3ImageEncoder(DINOv3ViTModel):\n    def __init__(self):\n        config = DINOv3ViTConfig(\n            architectures = [\n                \"DINOv3ViTModel\"\n            ],\n            attention_dropout = 0.0,\n            drop_path_rate = 0.0,\n            dtype = \"float32\",\n            hidden_act = \"silu\",\n            hidden_size = 4096,\n            image_size = 224,\n            initializer_range = 0.02,\n            intermediate_size = 8192,\n            key_bias = False,\n            layer_norm_eps = 1e-05,\n            layerscale_value = 1.0,\n            mlp_bias = True,\n            model_type = \"dinov3_vit\",\n            num_attention_heads = 32,\n            num_channels = 3,\n            num_hidden_layers = 40,\n            num_register_tokens = 4,\n            patch_size = 16,\n            pos_embed_jitter = None,\n            pos_embed_rescale = 2.0,\n            pos_embed_shift = None,\n            proj_bias = True,\n            query_bias = False,\n            rope_theta = 100.0,\n            transformers_version = \"4.56.1\",\n            use_gated_mlp = True,\n            value_bias = False\n        )\n        super().__init__(config)\n        self.processor = DINOv3ViTImageProcessorFast(\n            crop_size = None,\n            data_format = \"channels_first\",\n            default_to_square = True,\n            device = None,\n            disable_grouping = None,\n            do_center_crop = None,\n            do_convert_rgb = None,\n            do_normalize = True,\n            do_rescale = True,\n            do_resize = True,\n            image_mean = [\n                0.485,\n                0.456,\n                0.406\n            ],\n            image_processor_type 
= \"DINOv3ViTImageProcessorFast\",\n            image_std = [\n                0.229,\n                0.224,\n                0.225\n            ],\n            input_data_format = None,\n            resample = 2,\n            rescale_factor = 0.00392156862745098,\n            return_tensors = None,\n            size = {\n                \"height\": 224,\n                \"width\": 224\n            }\n        )\n        \n    def forward(self, image, torch_dtype=torch.bfloat16, device=get_device_type()):\n        inputs = self.processor(images=image, return_tensors=\"pt\")\n        pixel_values = inputs[\"pixel_values\"].to(dtype=torch_dtype, device=device)\n        bool_masked_pos = None\n        head_mask = None\n        \n        pixel_values = pixel_values.to(torch_dtype)\n        hidden_states = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)\n        position_embeddings = self.rope_embeddings(pixel_values)\n\n        for i, layer_module in enumerate(self.layer):\n            layer_head_mask = head_mask[i] if head_mask is not None else None\n            hidden_states = layer_module(\n                hidden_states,\n                attention_mask=layer_head_mask,\n                position_embeddings=position_embeddings,\n            )\n\n        sequence_output = self.norm(hidden_states)\n        pooled_output = sequence_output[:, 0, :]\n\n        return pooled_output\n"
  },
  {
    "path": "diffsynth/models/flux2_dit.py",
    "content": "import inspect\nfrom typing import Any, Dict, List, Optional, Tuple, Union\n\nimport torch, math\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport numpy as np\nfrom ..core.attention import attention_forward\nfrom ..core.gradient import gradient_checkpoint_forward\n\n\ndef get_timestep_embedding(\n    timesteps: torch.Tensor,\n    embedding_dim: int,\n    flip_sin_to_cos: bool = False,\n    downscale_freq_shift: float = 1,\n    scale: float = 1,\n    max_period: int = 10000,\n) -> torch.Tensor:\n    \"\"\"\n    This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.\n\n    Args\n        timesteps (torch.Tensor):\n            a 1-D Tensor of N indices, one per batch element. These may be fractional.\n        embedding_dim (int):\n            the dimension of the output.\n        flip_sin_to_cos (bool):\n            Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False)\n        downscale_freq_shift (float):\n            Controls the delta between frequencies between dimensions\n        scale (float):\n            Scaling factor applied to the embeddings.\n        max_period (int):\n            Controls the maximum frequency of the embeddings\n    Returns\n        torch.Tensor: an [N x dim] Tensor of positional embeddings.\n    \"\"\"\n    assert len(timesteps.shape) == 1, \"Timesteps should be a 1d-array\"\n\n    half_dim = embedding_dim // 2\n    exponent = -math.log(max_period) * torch.arange(\n        start=0, end=half_dim, dtype=torch.float32, device=timesteps.device\n    )\n    exponent = exponent / (half_dim - downscale_freq_shift)\n\n    emb = torch.exp(exponent)\n    emb = timesteps[:, None].float() * emb[None, :]\n\n    # scale embeddings\n    emb = scale * emb\n\n    # concat sine and cosine embeddings\n    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)\n\n    # flip sine and cosine embeddings\n    if flip_sin_to_cos:\n        
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)\n\n    # zero pad\n    if embedding_dim % 2 == 1:\n        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))\n    return emb\n\n\nclass TimestepEmbedding(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        time_embed_dim: int,\n        act_fn: str = \"silu\",\n        out_dim: int = None,\n        post_act_fn: Optional[str] = None,\n        cond_proj_dim=None,\n        sample_proj_bias=True,\n    ):\n        super().__init__()\n\n        self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)\n\n        if cond_proj_dim is not None:\n            self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)\n        else:\n            self.cond_proj = None\n\n        self.act = torch.nn.SiLU()\n\n        if out_dim is not None:\n            time_embed_dim_out = out_dim\n        else:\n            time_embed_dim_out = time_embed_dim\n        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)\n\n        if post_act_fn is None:\n            self.post_act = None\n\n    def forward(self, sample, condition=None):\n        if condition is not None:\n            sample = sample + self.cond_proj(condition)\n        sample = self.linear_1(sample)\n\n        if self.act is not None:\n            sample = self.act(sample)\n\n        sample = self.linear_2(sample)\n\n        if self.post_act is not None:\n            sample = self.post_act(sample)\n        return sample\n\n\nclass Timesteps(nn.Module):\n    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):\n        super().__init__()\n        self.num_channels = num_channels\n        self.flip_sin_to_cos = flip_sin_to_cos\n        self.downscale_freq_shift = downscale_freq_shift\n        self.scale = scale\n\n    def forward(self, timesteps: torch.Tensor) -> torch.Tensor:\n        t_emb = get_timestep_embedding(\n            
timesteps,\n            self.num_channels,\n            flip_sin_to_cos=self.flip_sin_to_cos,\n            downscale_freq_shift=self.downscale_freq_shift,\n            scale=self.scale,\n        )\n        return t_emb\n\n\nclass AdaLayerNormContinuous(nn.Module):\n    r\"\"\"\n    Adaptive normalization layer with a norm layer (layer_norm or rms_norm).\n\n    Args:\n        embedding_dim (`int`): Embedding dimension to use during projection.\n        conditioning_embedding_dim (`int`): Dimension of the input condition.\n        elementwise_affine (`bool`, defaults to `True`):\n            Boolean flag to denote if affine transformation should be applied.\n        eps (`float`, defaults to 1e-5): Epsilon factor.\n        bias (`bias`, defaults to `True`): Boolean flag to denote if bias should be use.\n        norm_type (`str`, defaults to `\"layer_norm\"`):\n            Normalization layer to use. Values supported: \"layer_norm\", \"rms_norm\".\n    \"\"\"\n\n    def __init__(\n        self,\n        embedding_dim: int,\n        conditioning_embedding_dim: int,\n        # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters\n        # because the output is immediately scaled and shifted by the projected conditioning embeddings.\n        # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.\n        # However, this is how it was implemented in the original code, and it's rather likely you should\n        # set `elementwise_affine` to False.\n        elementwise_affine=True,\n        eps=1e-5,\n        bias=True,\n        norm_type=\"layer_norm\",\n    ):\n        super().__init__()\n        self.silu = nn.SiLU()\n        self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)\n        if norm_type == \"layer_norm\":\n            self.norm = nn.LayerNorm(embedding_dim, eps, elementwise_affine, bias)\n\n    def forward(self, x: torch.Tensor, conditioning_embedding: 
torch.Tensor) -> torch.Tensor:\n        # convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT)\n        emb = self.linear(self.silu(conditioning_embedding).to(x.dtype))\n        scale, shift = torch.chunk(emb, 2, dim=1)\n        x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]\n        return x\n\n\ndef get_1d_rotary_pos_embed(\n    dim: int,\n    pos: Union[np.ndarray, int],\n    theta: float = 10000.0,\n    use_real=False,\n    linear_factor=1.0,\n    ntk_factor=1.0,\n    repeat_interleave_real=True,\n    freqs_dtype=torch.float32,  #  torch.float32, torch.float64 (flux)\n):\n    \"\"\"\n    Precompute the frequency tensor for complex exponentials (cis) with given dimensions.\n\n    This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end\n    index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64\n    data type.\n\n    Args:\n        dim (`int`): Dimension of the frequency tensor.\n        pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar\n        theta (`float`, *optional*, defaults to 10000.0):\n            Scaling factor for frequency computation. Defaults to 10000.0.\n        use_real (`bool`, *optional*):\n            If True, return real part and imaginary part separately. Otherwise, return complex numbers.\n        linear_factor (`float`, *optional*, defaults to 1.0):\n            Scaling factor for the context extrapolation. Defaults to 1.0.\n        ntk_factor (`float`, *optional*, defaults to 1.0):\n            Scaling factor for the NTK-Aware RoPE. 
Defaults to 1.0.\n        repeat_interleave_real (`bool`, *optional*, defaults to `True`):\n            If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`.\n            Otherwise, they are concateanted with themselves.\n        freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`):\n            the dtype of the frequency tensor.\n    Returns:\n        `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]\n    \"\"\"\n    assert dim % 2 == 0\n\n    if isinstance(pos, int):\n        pos = torch.arange(pos)\n    if isinstance(pos, np.ndarray):\n        pos = torch.from_numpy(pos)  # type: ignore  # [S]\n\n    theta = theta * ntk_factor\n    freqs = (\n        1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device) / dim)) / linear_factor\n    )  # [D/2]\n    freqs = torch.outer(pos, freqs)  # type: ignore   # [S, D/2]\n    is_npu = freqs.device.type == \"npu\"\n    if is_npu:\n        freqs = freqs.float()\n    if use_real and repeat_interleave_real:\n        # flux, hunyuan-dit, cogvideox\n        freqs_cos = freqs.cos().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float()  # [S, D]\n        freqs_sin = freqs.sin().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float()  # [S, D]\n        return freqs_cos, freqs_sin\n    elif use_real:\n        # stable audio, allegro\n        freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float()  # [S, D]\n        freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float()  # [S, D]\n        return freqs_cos, freqs_sin\n    else:\n        # lumina\n        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64     # [S, D/2]\n        return freqs_cis\n\n\ndef apply_rotary_emb(\n    x: torch.Tensor,\n    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],\n    use_real: bool = True,\n    use_real_unbind_dim: int = -1,\n    sequence_dim: 
int = 2,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n    \"\"\"\n    Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings\n    to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are\n    reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting\n    tensors contain rotary embeddings and are returned as real tensors.\n\n    Args:\n        x (`torch.Tensor`):\n            Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply\n        freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)\n\n    Returns:\n        Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.\n    \"\"\"\n    if use_real:\n        cos, sin = freqs_cis  # [S, D]\n        if sequence_dim == 2:\n            cos = cos[None, None, :, :]\n            sin = sin[None, None, :, :]\n        elif sequence_dim == 1:\n            cos = cos[None, :, None, :]\n            sin = sin[None, :, None, :]\n        else:\n            raise ValueError(f\"`sequence_dim={sequence_dim}` but should be 1 or 2.\")\n\n        cos, sin = cos.to(x.device), sin.to(x.device)\n\n        if use_real_unbind_dim == -1:\n            # Used for flux, cogvideox, hunyuan-dit\n            x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, H, S, D//2]\n            x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)\n        elif use_real_unbind_dim == -2:\n            # Used for Stable Audio, OmniGen, CogView4 and Cosmos\n            x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)  # [B, H, S, D//2]\n            x_rotated = torch.cat([-x_imag, x_real], dim=-1)\n        else:\n            raise ValueError(f\"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.\")\n\n        
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)\n\n        return out\n    else:\n        # used for lumina\n        x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))\n        freqs_cis = freqs_cis.unsqueeze(2)\n        x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)\n\n        return x_out.type_as(x)\n\ndef _get_projections(attn: \"Flux2Attention\", hidden_states, encoder_hidden_states=None):\n    query = attn.to_q(hidden_states)\n    key = attn.to_k(hidden_states)\n    value = attn.to_v(hidden_states)\n\n    encoder_query = encoder_key = encoder_value = None\n    if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:\n        encoder_query = attn.add_q_proj(encoder_hidden_states)\n        encoder_key = attn.add_k_proj(encoder_hidden_states)\n        encoder_value = attn.add_v_proj(encoder_hidden_states)\n\n    return query, key, value, encoder_query, encoder_key, encoder_value\n\n\ndef _get_fused_projections(attn: \"Flux2Attention\", hidden_states, encoder_hidden_states=None):\n    query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)\n\n    encoder_query = encoder_key = encoder_value = (None,)\n    if encoder_hidden_states is not None and hasattr(attn, \"to_added_qkv\"):\n        encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1)\n\n    return query, key, value, encoder_query, encoder_key, encoder_value\n\n\ndef _get_qkv_projections(attn: \"Flux2Attention\", hidden_states, encoder_hidden_states=None):\n    return _get_projections(attn, hidden_states, encoder_hidden_states)\n\n\nclass Flux2SwiGLU(nn.Module):\n    \"\"\"\n    Flux 2 uses a SwiGLU-style activation in the transformer feedforward sub-blocks, but with the linear projection\n    layer fused into the first linear layer of the FF sub-block. 
Thus, this module has no trainable parameters.\n    \"\"\"\n\n    def __init__(self):\n        super().__init__()\n        self.gate_fn = nn.SiLU()\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x1, x2 = x.chunk(2, dim=-1)\n        x = self.gate_fn(x1) * x2\n        return x\n\n\nclass Flux2FeedForward(nn.Module):\n    def __init__(\n        self,\n        dim: int,\n        dim_out: Optional[int] = None,\n        mult: float = 3.0,\n        inner_dim: Optional[int] = None,\n        bias: bool = False,\n    ):\n        super().__init__()\n        if inner_dim is None:\n            inner_dim = int(dim * mult)\n        dim_out = dim_out or dim\n\n        # Flux2SwiGLU will reduce the dimension by half\n        self.linear_in = nn.Linear(dim, inner_dim * 2, bias=bias)\n        self.act_fn = Flux2SwiGLU()\n        self.linear_out = nn.Linear(inner_dim, dim_out, bias=bias)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x = self.linear_in(x)\n        x = self.act_fn(x)\n        x = self.linear_out(x)\n        return x\n\n\nclass Flux2AttnProcessor:\n    _attention_backend = None\n    _parallel_config = None\n\n    def __init__(self):\n        if not hasattr(F, \"scaled_dot_product_attention\"):\n            raise ImportError(f\"{self.__class__.__name__} requires PyTorch 2.0. 
Please upgrade your pytorch version.\")\n\n    def __call__(\n        self,\n        attn: \"Flux2Attention\",\n        hidden_states: torch.Tensor,\n        encoder_hidden_states: torch.Tensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        image_rotary_emb: Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n        query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(\n            attn, hidden_states, encoder_hidden_states\n        )\n\n        query = query.unflatten(-1, (attn.heads, -1))\n        key = key.unflatten(-1, (attn.heads, -1))\n        value = value.unflatten(-1, (attn.heads, -1))\n\n        query = attn.norm_q(query)\n        key = attn.norm_k(key)\n\n        if attn.added_kv_proj_dim is not None:\n            encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))\n            encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))\n            encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))\n\n            encoder_query = attn.norm_added_q(encoder_query)\n            encoder_key = attn.norm_added_k(encoder_key)\n\n            query = torch.cat([encoder_query, query], dim=1)\n            key = torch.cat([encoder_key, key], dim=1)\n            value = torch.cat([encoder_value, value], dim=1)\n\n        if image_rotary_emb is not None:\n            query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)\n            key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)\n\n        query, key, value = query.to(hidden_states.dtype), key.to(hidden_states.dtype), value.to(hidden_states.dtype)\n        hidden_states = attention_forward(\n            query,\n            key,\n            value,\n            q_pattern=\"b s n d\", k_pattern=\"b s n d\", v_pattern=\"b s n d\", out_pattern=\"b s n d\",\n        )\n        hidden_states = hidden_states.flatten(2, 3)\n        hidden_states = hidden_states.to(query.dtype)\n\n        if encoder_hidden_states is not 
None:\n            encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(\n                [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1\n            )\n            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)\n\n        hidden_states = attn.to_out[0](hidden_states)\n        hidden_states = attn.to_out[1](hidden_states)\n\n        if encoder_hidden_states is not None:\n            return hidden_states, encoder_hidden_states\n        else:\n            return hidden_states\n\n\nclass Flux2Attention(torch.nn.Module):\n    _default_processor_cls = Flux2AttnProcessor\n    _available_processors = [Flux2AttnProcessor]\n\n    def __init__(\n        self,\n        query_dim: int,\n        heads: int = 8,\n        dim_head: int = 64,\n        dropout: float = 0.0,\n        bias: bool = False,\n        added_kv_proj_dim: Optional[int] = None,\n        added_proj_bias: Optional[bool] = True,\n        out_bias: bool = True,\n        eps: float = 1e-5,\n        out_dim: int = None,\n        elementwise_affine: bool = True,\n        processor=None,\n    ):\n        super().__init__()\n\n        self.head_dim = dim_head\n        self.inner_dim = out_dim if out_dim is not None else dim_head * heads\n        self.query_dim = query_dim\n        self.out_dim = out_dim if out_dim is not None else query_dim\n        self.heads = out_dim // dim_head if out_dim is not None else heads\n\n        self.use_bias = bias\n        self.dropout = dropout\n\n        self.added_kv_proj_dim = added_kv_proj_dim\n        self.added_proj_bias = added_proj_bias\n\n        self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)\n        self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)\n        self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)\n\n        # QK Norm\n        self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)\n        
self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)\n\n        self.to_out = torch.nn.ModuleList([])\n        self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))\n        self.to_out.append(torch.nn.Dropout(dropout))\n\n        if added_kv_proj_dim is not None:\n            self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)\n            self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)\n            self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)\n            self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)\n            self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)\n            self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias)\n\n        if processor is None:\n            processor = self._default_processor_cls()\n        self.processor = processor\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        encoder_hidden_states: Optional[torch.Tensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        image_rotary_emb: Optional[torch.Tensor] = None,\n        **kwargs,\n    ) -> torch.Tensor:\n        attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())\n        kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}\n        return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)\n\n\nclass Flux2ParallelSelfAttnProcessor:\n    _attention_backend = None\n    _parallel_config = None\n\n    def __init__(self):\n        if not hasattr(F, \"scaled_dot_product_attention\"):\n            raise ImportError(f\"{self.__class__.__name__} requires PyTorch 2.0. 
Please upgrade your pytorch version.\")\n\n    def __call__(\n        self,\n        attn: \"Flux2ParallelSelfAttention\",\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        image_rotary_emb: Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n        # Parallel in (QKV + MLP in) projection\n        hidden_states = attn.to_qkv_mlp_proj(hidden_states)\n        qkv, mlp_hidden_states = torch.split(\n            hidden_states, [3 * attn.inner_dim, attn.mlp_hidden_dim * attn.mlp_mult_factor], dim=-1\n        )\n\n        # Handle the attention logic\n        query, key, value = qkv.chunk(3, dim=-1)\n\n        query = query.unflatten(-1, (attn.heads, -1))\n        key = key.unflatten(-1, (attn.heads, -1))\n        value = value.unflatten(-1, (attn.heads, -1))\n\n        query = attn.norm_q(query)\n        key = attn.norm_k(key)\n\n        if image_rotary_emb is not None:\n            query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)\n            key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)\n\n        query, key, value = query.to(hidden_states.dtype), key.to(hidden_states.dtype), value.to(hidden_states.dtype)\n        hidden_states = attention_forward(\n            query,\n            key,\n            value,\n            q_pattern=\"b s n d\", k_pattern=\"b s n d\", v_pattern=\"b s n d\", out_pattern=\"b s n d\",\n        )\n        hidden_states = hidden_states.flatten(2, 3)\n        hidden_states = hidden_states.to(query.dtype)\n\n        # Handle the feedforward (FF) logic\n        mlp_hidden_states = attn.mlp_act_fn(mlp_hidden_states)\n\n        # Concatenate and parallel output projection\n        hidden_states = torch.cat([hidden_states, mlp_hidden_states], dim=-1)\n        hidden_states = attn.to_out(hidden_states)\n\n        return hidden_states\n\n\nclass Flux2ParallelSelfAttention(torch.nn.Module):\n    \"\"\"\n    Flux 2 parallel self-attention for the Flux 2 
single-stream transformer blocks.\n\n    This implements a parallel transformer block, where the attention QKV projections are fused to the feedforward (FF)\n    input projections, and the attention output projections are fused to the FF output projections. See the [ViT-22B\n    paper](https://arxiv.org/abs/2302.05442) for a visual depiction of this type of transformer block.\n    \"\"\"\n\n    _default_processor_cls = Flux2ParallelSelfAttnProcessor\n    _available_processors = [Flux2ParallelSelfAttnProcessor]\n    # Does not support QKV fusion as the QKV projections are always fused\n    _supports_qkv_fusion = False\n\n    def __init__(\n        self,\n        query_dim: int,\n        heads: int = 8,\n        dim_head: int = 64,\n        dropout: float = 0.0,\n        bias: bool = False,\n        out_bias: bool = True,\n        eps: float = 1e-5,\n        out_dim: int = None,\n        elementwise_affine: bool = True,\n        mlp_ratio: float = 4.0,\n        mlp_mult_factor: int = 2,\n        processor=None,\n    ):\n        super().__init__()\n\n        self.head_dim = dim_head\n        self.inner_dim = out_dim if out_dim is not None else dim_head * heads\n        self.query_dim = query_dim\n        self.out_dim = out_dim if out_dim is not None else query_dim\n        self.heads = out_dim // dim_head if out_dim is not None else heads\n\n        self.use_bias = bias\n        self.dropout = dropout\n\n        self.mlp_ratio = mlp_ratio\n        self.mlp_hidden_dim = int(query_dim * self.mlp_ratio)\n        self.mlp_mult_factor = mlp_mult_factor\n\n        # Fused QKV projections + MLP input projection\n        self.to_qkv_mlp_proj = torch.nn.Linear(\n            self.query_dim, self.inner_dim * 3 + self.mlp_hidden_dim * self.mlp_mult_factor, bias=bias\n        )\n        self.mlp_act_fn = Flux2SwiGLU()\n\n        # QK Norm\n        self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)\n        self.norm_k = 
torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)\n\n        # Fused attention output projection + MLP output projection\n        self.to_out = torch.nn.Linear(self.inner_dim + self.mlp_hidden_dim, self.out_dim, bias=out_bias)\n\n        if processor is None:\n            processor = self._default_processor_cls()\n        self.processor = processor\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        image_rotary_emb: Optional[torch.Tensor] = None,\n        **kwargs,\n    ) -> torch.Tensor:\n        attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())\n        kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}\n        return self.processor(self, hidden_states, attention_mask, image_rotary_emb, **kwargs)\n\n\nclass Flux2SingleTransformerBlock(nn.Module):\n    def __init__(\n        self,\n        dim: int,\n        num_attention_heads: int,\n        attention_head_dim: int,\n        mlp_ratio: float = 3.0,\n        eps: float = 1e-6,\n        bias: bool = False,\n    ):\n        super().__init__()\n\n        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)\n\n        # Note that the MLP in/out linear layers are fused with the attention QKV/out projections, respectively; this\n        # is often called a \"parallel\" transformer block. 
See the [ViT-22B paper](https://arxiv.org/abs/2302.05442)\n        # for a visual depiction of this type of transformer block.\n        self.attn = Flux2ParallelSelfAttention(\n            query_dim=dim,\n            dim_head=attention_head_dim,\n            heads=num_attention_heads,\n            out_dim=dim,\n            bias=bias,\n            out_bias=bias,\n            eps=eps,\n            mlp_ratio=mlp_ratio,\n            mlp_mult_factor=2,\n            processor=Flux2ParallelSelfAttnProcessor(),\n        )\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        encoder_hidden_states: Optional[torch.Tensor],\n        temb_mod_params: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],\n        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,\n        joint_attention_kwargs: Optional[Dict[str, Any]] = None,\n        split_hidden_states: bool = False,\n        text_seq_len: Optional[int] = None,\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        # If encoder_hidden_states is None, hidden_states is assumed to have encoder_hidden_states already\n        # concatenated\n        if encoder_hidden_states is not None:\n            text_seq_len = encoder_hidden_states.shape[1]\n            hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)\n\n        mod_shift, mod_scale, mod_gate = temb_mod_params\n\n        norm_hidden_states = self.norm(hidden_states)\n        norm_hidden_states = (1 + mod_scale) * norm_hidden_states + mod_shift\n\n        joint_attention_kwargs = joint_attention_kwargs or {}\n        attn_output = self.attn(\n            hidden_states=norm_hidden_states,\n            image_rotary_emb=image_rotary_emb,\n            **joint_attention_kwargs,\n        )\n\n        hidden_states = hidden_states + mod_gate * attn_output\n        if hidden_states.dtype == torch.float16:\n            hidden_states = hidden_states.clip(-65504, 65504)\n\n        if split_hidden_states:\n            
encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:]\n            return encoder_hidden_states, hidden_states\n        else:\n            return hidden_states\n\n\nclass Flux2TransformerBlock(nn.Module):\n    def __init__(\n        self,\n        dim: int,\n        num_attention_heads: int,\n        attention_head_dim: int,\n        mlp_ratio: float = 3.0,\n        eps: float = 1e-6,\n        bias: bool = False,\n    ):\n        super().__init__()\n        self.mlp_hidden_dim = int(dim * mlp_ratio)\n\n        self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)\n        self.norm1_context = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)\n\n        self.attn = Flux2Attention(\n            query_dim=dim,\n            added_kv_proj_dim=dim,\n            dim_head=attention_head_dim,\n            heads=num_attention_heads,\n            out_dim=dim,\n            bias=bias,\n            added_proj_bias=bias,\n            out_bias=bias,\n            eps=eps,\n            processor=Flux2AttnProcessor(),\n        )\n\n        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)\n        self.ff = Flux2FeedForward(dim=dim, dim_out=dim, mult=mlp_ratio, bias=bias)\n\n        self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)\n        self.ff_context = Flux2FeedForward(dim=dim, dim_out=dim, mult=mlp_ratio, bias=bias)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        encoder_hidden_states: torch.Tensor,\n        temb_mod_params_img: Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...],\n        temb_mod_params_txt: Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...],\n        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,\n        joint_attention_kwargs: Optional[Dict[str, Any]] = None,\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        joint_attention_kwargs = joint_attention_kwargs or {}\n\n        # 
Modulation parameters shape: [1, 1, self.dim]\n        (shift_msa, scale_msa, gate_msa), (shift_mlp, scale_mlp, gate_mlp) = temb_mod_params_img\n        (c_shift_msa, c_scale_msa, c_gate_msa), (c_shift_mlp, c_scale_mlp, c_gate_mlp) = temb_mod_params_txt\n\n        # Img stream\n        norm_hidden_states = self.norm1(hidden_states)\n        norm_hidden_states = (1 + scale_msa) * norm_hidden_states + shift_msa\n\n        # Conditioning txt stream\n        norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states)\n        norm_encoder_hidden_states = (1 + c_scale_msa) * norm_encoder_hidden_states + c_shift_msa\n\n        # Attention on concatenated img + txt stream\n        attention_outputs = self.attn(\n            hidden_states=norm_hidden_states,\n            encoder_hidden_states=norm_encoder_hidden_states,\n            image_rotary_emb=image_rotary_emb,\n            **joint_attention_kwargs,\n        )\n\n        attn_output, context_attn_output = attention_outputs\n\n        # Process attention outputs for the image stream (`hidden_states`).\n        attn_output = gate_msa * attn_output\n        hidden_states = hidden_states + attn_output\n\n        norm_hidden_states = self.norm2(hidden_states)\n        norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp\n\n        ff_output = self.ff(norm_hidden_states)\n        hidden_states = hidden_states + gate_mlp * ff_output\n\n        # Process attention outputs for the text stream (`encoder_hidden_states`).\n        context_attn_output = c_gate_msa * context_attn_output\n        encoder_hidden_states = encoder_hidden_states + context_attn_output\n\n        norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)\n        norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp) + c_shift_mlp\n\n        context_ff_output = self.ff_context(norm_encoder_hidden_states)\n        encoder_hidden_states = encoder_hidden_states + c_gate_mlp * context_ff_output\n   
     if encoder_hidden_states.dtype == torch.float16:\n            encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)\n\n        return encoder_hidden_states, hidden_states\n\n\nclass Flux2PosEmbed(nn.Module):\n    # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11\n    def __init__(self, theta: int, axes_dim: List[int]):\n        super().__init__()\n        self.theta = theta\n        self.axes_dim = axes_dim\n\n    def forward(self, ids: torch.Tensor) -> torch.Tensor:\n        # Expected ids shape: [S, len(self.axes_dim)]\n        cos_out = []\n        sin_out = []\n        pos = ids.float()\n        is_mps = ids.device.type == \"mps\"\n        is_npu = ids.device.type == \"npu\"\n        freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64\n        # Unlike Flux 1, loop over len(self.axes_dim) rather than ids.shape[-1]\n        for i in range(len(self.axes_dim)):\n            cos, sin = get_1d_rotary_pos_embed(\n                self.axes_dim[i],\n                pos[..., i],\n                theta=self.theta,\n                repeat_interleave_real=True,\n                use_real=True,\n                freqs_dtype=freqs_dtype,\n            )\n            cos_out.append(cos)\n            sin_out.append(sin)\n        freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)\n        freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)\n        return freqs_cos, freqs_sin\n\n\nclass Flux2TimestepGuidanceEmbeddings(nn.Module):\n    def __init__(\n        self,\n        in_channels: int = 256,\n        embedding_dim: int = 6144,\n        bias: bool = False,\n        guidance_embeds: bool = True,\n    ):\n        super().__init__()\n\n        self.time_proj = Timesteps(num_channels=in_channels, flip_sin_to_cos=True, downscale_freq_shift=0)\n        self.timestep_embedder = TimestepEmbedding(\n            in_channels=in_channels, time_embed_dim=embedding_dim, 
sample_proj_bias=bias\n        )\n\n        if guidance_embeds:\n            self.guidance_embedder = TimestepEmbedding(\n                in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias\n            )\n        else:\n            self.guidance_embedder = None\n\n    def forward(self, timestep: torch.Tensor, guidance: torch.Tensor) -> torch.Tensor:\n        timesteps_proj = self.time_proj(timestep)\n        timesteps_emb = self.timestep_embedder(timesteps_proj.to(timestep.dtype))  # (N, D)\n\n        if guidance is not None and self.guidance_embedder is not None:\n            guidance_proj = self.time_proj(guidance)\n            guidance_emb = self.guidance_embedder(guidance_proj.to(guidance.dtype))  # (N, D)\n            time_guidance_emb = timesteps_emb + guidance_emb\n            return time_guidance_emb\n        else:\n            return timesteps_emb\n\n\nclass Flux2Modulation(nn.Module):\n    def __init__(self, dim: int, mod_param_sets: int = 2, bias: bool = False):\n        super().__init__()\n        self.mod_param_sets = mod_param_sets\n\n        self.linear = nn.Linear(dim, dim * 3 * self.mod_param_sets, bias=bias)\n        self.act_fn = nn.SiLU()\n\n    def forward(self, temb: torch.Tensor) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...]:\n        mod = self.act_fn(temb)\n        mod = self.linear(mod)\n\n        if mod.ndim == 2:\n            mod = mod.unsqueeze(1)\n        mod_params = torch.chunk(mod, 3 * self.mod_param_sets, dim=-1)\n        # Return tuple of 3-tuples of modulation params shift/scale/gate\n        return tuple(mod_params[3 * i : 3 * (i + 1)] for i in range(self.mod_param_sets))\n\n\nclass Flux2DiT(torch.nn.Module):\n    def __init__(\n        self,\n        patch_size: int = 1,\n        in_channels: int = 128,\n        out_channels: Optional[int] = None,\n        num_layers: int = 8,\n        num_single_layers: int = 48,\n        attention_head_dim: int = 128,\n        num_attention_heads: 
int = 48,\n        joint_attention_dim: int = 15360,\n        timestep_guidance_channels: int = 256,\n        mlp_ratio: float = 3.0,\n        axes_dims_rope: Tuple[int, ...] = (32, 32, 32, 32),\n        rope_theta: int = 2000,\n        eps: float = 1e-6,\n        guidance_embeds: bool = True,\n    ):\n        super().__init__()\n        self.out_channels = out_channels or in_channels\n        self.inner_dim = num_attention_heads * attention_head_dim\n\n        # 1. Sinusoidal positional embedding for RoPE on image and text tokens\n        self.pos_embed = Flux2PosEmbed(theta=rope_theta, axes_dim=axes_dims_rope)\n\n        # 2. Combined timestep + guidance embedding\n        self.time_guidance_embed = Flux2TimestepGuidanceEmbeddings(\n            in_channels=timestep_guidance_channels,\n            embedding_dim=self.inner_dim,\n            bias=False,\n            guidance_embeds=guidance_embeds,\n        )\n\n        # 3. Modulation (double stream and single stream blocks share modulation parameters, resp.)\n        # Two sets of shift/scale/gate modulation parameters for the double stream attn and FF sub-blocks\n        self.double_stream_modulation_img = Flux2Modulation(self.inner_dim, mod_param_sets=2, bias=False)\n        self.double_stream_modulation_txt = Flux2Modulation(self.inner_dim, mod_param_sets=2, bias=False)\n        # Only one set of modulation parameters as the attn and FF sub-blocks are run in parallel for single stream\n        self.single_stream_modulation = Flux2Modulation(self.inner_dim, mod_param_sets=1, bias=False)\n\n        # 4. Input projections\n        self.x_embedder = nn.Linear(in_channels, self.inner_dim, bias=False)\n        self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim, bias=False)\n\n        # 5. 
Double Stream Transformer Blocks\n        self.transformer_blocks = nn.ModuleList(\n            [\n                Flux2TransformerBlock(\n                    dim=self.inner_dim,\n                    num_attention_heads=num_attention_heads,\n                    attention_head_dim=attention_head_dim,\n                    mlp_ratio=mlp_ratio,\n                    eps=eps,\n                    bias=False,\n                )\n                for _ in range(num_layers)\n            ]\n        )\n\n        # 6. Single Stream Transformer Blocks\n        self.single_transformer_blocks = nn.ModuleList(\n            [\n                Flux2SingleTransformerBlock(\n                    dim=self.inner_dim,\n                    num_attention_heads=num_attention_heads,\n                    attention_head_dim=attention_head_dim,\n                    mlp_ratio=mlp_ratio,\n                    eps=eps,\n                    bias=False,\n                )\n                for _ in range(num_single_layers)\n            ]\n        )\n\n        # 7. Output layers\n        self.norm_out = AdaLayerNormContinuous(\n            self.inner_dim, self.inner_dim, elementwise_affine=False, eps=eps, bias=False\n        )\n        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=False)\n\n        self.gradient_checkpointing = False\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        encoder_hidden_states: torch.Tensor = None,\n        timestep: torch.LongTensor = None,\n        img_ids: torch.Tensor = None,\n        txt_ids: torch.Tensor = None,\n        guidance: torch.Tensor = None,\n        joint_attention_kwargs: Optional[Dict[str, Any]] = None,\n        use_gradient_checkpointing=False,\n        use_gradient_checkpointing_offload=False,\n    ):\n        # 0. 
Handle input arguments\n        if joint_attention_kwargs is not None:\n            joint_attention_kwargs = joint_attention_kwargs.copy()\n            lora_scale = joint_attention_kwargs.pop(\"scale\", 1.0)\n        else:\n            lora_scale = 1.0\n\n        num_txt_tokens = encoder_hidden_states.shape[1]\n\n        # 1. Calculate timestep embedding and modulation parameters\n        timestep = timestep.to(hidden_states.dtype) * 1000\n\n        if guidance is not None:\n            guidance = guidance.to(hidden_states.dtype) * 1000\n\n        temb = self.time_guidance_embed(timestep, guidance)\n\n        double_stream_mod_img = self.double_stream_modulation_img(temb)\n        double_stream_mod_txt = self.double_stream_modulation_txt(temb)\n        single_stream_mod = self.single_stream_modulation(temb)[0]\n\n        # 2. Input projection for image (hidden_states) and conditioning text (encoder_hidden_states)\n        hidden_states = self.x_embedder(hidden_states)\n        encoder_hidden_states = self.context_embedder(encoder_hidden_states)\n\n        # 3. Calculate RoPE embeddings from image and text tokens\n        # NOTE: the below logic means that we can't support batched inference with images of different resolutions or\n        # text prompts of differents lengths. Is this a use case we want to support?\n        if img_ids.ndim == 3:\n            img_ids = img_ids[0]\n        if txt_ids.ndim == 3:\n            txt_ids = txt_ids[0]\n\n        image_rotary_emb = self.pos_embed(img_ids)\n        text_rotary_emb = self.pos_embed(txt_ids)\n        concat_rotary_emb = (\n            torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0),\n            torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0),\n        )\n\n        # 4. 
Double Stream Transformer Blocks\n        for index_block, block in enumerate(self.transformer_blocks):\n            encoder_hidden_states, hidden_states = gradient_checkpoint_forward(\n                block,\n                use_gradient_checkpointing=use_gradient_checkpointing,\n                use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,\n                hidden_states=hidden_states,\n                encoder_hidden_states=encoder_hidden_states,\n                temb_mod_params_img=double_stream_mod_img,\n                temb_mod_params_txt=double_stream_mod_txt,\n                image_rotary_emb=concat_rotary_emb,\n                joint_attention_kwargs=joint_attention_kwargs,\n            )\n        # Concatenate text and image streams for single-block inference\n        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)\n\n        # 5. Single Stream Transformer Blocks\n        for index_block, block in enumerate(self.single_transformer_blocks):\n            hidden_states = gradient_checkpoint_forward(\n                block,\n                use_gradient_checkpointing=use_gradient_checkpointing,\n                use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,\n                hidden_states=hidden_states,\n                encoder_hidden_states=None,\n                temb_mod_params=single_stream_mod,\n                image_rotary_emb=concat_rotary_emb,\n                joint_attention_kwargs=joint_attention_kwargs,\n            )\n        # Remove text tokens from concatenated stream\n        hidden_states = hidden_states[:, num_txt_tokens:, ...]\n\n        # 6. Output layers\n        hidden_states = self.norm_out(hidden_states, temb)\n        output = self.proj_out(hidden_states)\n\n        return output\n"
  },
  {
    "path": "diffsynth/models/flux2_text_encoder.py",
    "content": "from transformers import Mistral3ForConditionalGeneration, Mistral3Config\n\n\nclass Flux2TextEncoder(Mistral3ForConditionalGeneration):\n    def __init__(self):\n        config = Mistral3Config(**{\n            \"architectures\": [\n                \"Mistral3ForConditionalGeneration\"\n            ],\n            \"dtype\": \"bfloat16\",\n            \"image_token_index\": 10,\n            \"model_type\": \"mistral3\",\n            \"multimodal_projector_bias\": False,\n            \"projector_hidden_act\": \"gelu\",\n            \"spatial_merge_size\": 2,\n            \"text_config\": {\n                \"attention_dropout\": 0.0,\n                \"dtype\": \"bfloat16\",\n                \"head_dim\": 128,\n                \"hidden_act\": \"silu\",\n                \"hidden_size\": 5120,\n                \"initializer_range\": 0.02,\n                \"intermediate_size\": 32768,\n                \"max_position_embeddings\": 131072,\n                \"model_type\": \"mistral\",\n                \"num_attention_heads\": 32,\n                \"num_hidden_layers\": 40,\n                \"num_key_value_heads\": 8,\n                \"rms_norm_eps\": 1e-05,\n                \"rope_theta\": 1000000000.0,\n                \"sliding_window\": None,\n                \"use_cache\": True,\n                \"vocab_size\": 131072\n            },\n            \"transformers_version\": \"4.57.1\",\n            \"vision_config\": {\n                \"attention_dropout\": 0.0,\n                \"dtype\": \"bfloat16\",\n                \"head_dim\": 64,\n                \"hidden_act\": \"silu\",\n                \"hidden_size\": 1024,\n                \"image_size\": 1540,\n                \"initializer_range\": 0.02,\n                \"intermediate_size\": 4096,\n                \"model_type\": \"pixtral\",\n                \"num_attention_heads\": 16,\n                \"num_channels\": 3,\n                \"num_hidden_layers\": 24,\n                \"patch_size\": 
14,\n                \"rope_theta\": 10000.0\n            },\n            \"vision_feature_layer\": -1\n        })\n        super().__init__(config)\n    \n    def forward(self, input_ids = None, pixel_values = None, attention_mask = None, position_ids = None, past_key_values = None, inputs_embeds = None, labels = None, use_cache = None, output_attentions = None, output_hidden_states = None, return_dict = None, cache_position = None, logits_to_keep = 0, image_sizes = None, **kwargs):\n        return super().forward(input_ids, pixel_values, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, logits_to_keep, image_sizes, **kwargs)\n\n"
  },
  {
    "path": "diffsynth/models/flux2_vae.py",
    "content": "# Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport math\nfrom typing import Dict, Optional, Tuple, Union, Callable\n\nimport torch\nimport torch.nn as nn\nfrom einops import rearrange\nimport torch.nn.functional as F\nimport inspect\n\nACT2CLS = {\n    \"swish\": nn.SiLU,\n    \"silu\": nn.SiLU,\n    \"mish\": nn.Mish,\n    \"gelu\": nn.GELU,\n    \"relu\": nn.ReLU,\n}\n\ndef get_activation(act_fn: str) -> nn.Module:\n    \"\"\"Helper function to get activation function from string.\n\n    Args:\n        act_fn (str): Name of activation function.\n\n    Returns:\n        nn.Module: Activation function.\n    \"\"\"\n\n    act_fn = act_fn.lower()\n    if act_fn in ACT2CLS:\n        return ACT2CLS[act_fn]()\n    else:\n        raise ValueError(f\"activation function {act_fn} not found in ACT2FN mapping {list(ACT2CLS.keys())}\")\n\nclass ResnetBlock2D(nn.Module):\n    r\"\"\"\n    A Resnet block.\n\n    Parameters:\n        in_channels (`int`): The number of channels in the input.\n        out_channels (`int`, *optional*, default to be `None`):\n            The number of output channels for the first conv2d layer. 
If None, same as `in_channels`.\n        dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.\n        temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.\n        groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.\n        groups_out (`int`, *optional*, default to None):\n            The number of groups to use for the second normalization layer. if set to None, same as `groups`.\n        eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.\n        non_linearity (`str`, *optional*, default to `\"swish\"`): the activation function to use.\n        time_embedding_norm (`str`, *optional*, default to `\"default\"` ): Time scale shift config.\n            By default, apply timestep embedding conditioning with a simple shift mechanism. Choose \"scale_shift\" for a\n            stronger conditioning with scale and shift.\n        kernel (`torch.Tensor`, optional, default to None): FIR filter, see\n            [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].\n        output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.\n        use_in_shortcut (`bool`, *optional*, default to `True`):\n            If `True`, add a 1x1 nn.conv2d layer for skip-connection.\n        up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer.\n        down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer.\n        conv_shortcut_bias (`bool`, *optional*, default to `True`):  If `True`, adds a learnable bias to the\n            `conv_shortcut` output.\n        conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output.\n            If None, same as `out_channels`.\n    \"\"\"\n\n    def __init__(\n        self,\n        *,\n        in_channels: int,\n        out_channels: 
Optional[int] = None,\n        conv_shortcut: bool = False,\n        dropout: float = 0.0,\n        temb_channels: int = 512,\n        groups: int = 32,\n        groups_out: Optional[int] = None,\n        pre_norm: bool = True,\n        eps: float = 1e-6,\n        non_linearity: str = \"swish\",\n        skip_time_act: bool = False,\n        time_embedding_norm: str = \"default\",  # default, scale_shift,\n        kernel: Optional[torch.Tensor] = None,\n        output_scale_factor: float = 1.0,\n        use_in_shortcut: Optional[bool] = None,\n        up: bool = False,\n        down: bool = False,\n        conv_shortcut_bias: bool = True,\n        conv_2d_out_channels: Optional[int] = None,\n    ):\n        super().__init__()\n        if time_embedding_norm == \"ada_group\":\n            raise ValueError(\n                \"This class cannot be used with `time_embedding_norm==ada_group`, please use `ResnetBlockCondNorm2D` instead\",\n            )\n        if time_embedding_norm == \"spatial\":\n            raise ValueError(\n                \"This class cannot be used with `time_embedding_norm==spatial`, please use `ResnetBlockCondNorm2D` instead\",\n            )\n\n        self.pre_norm = True\n        self.in_channels = in_channels\n        out_channels = in_channels if out_channels is None else out_channels\n        self.out_channels = out_channels\n        self.use_conv_shortcut = conv_shortcut\n        self.up = up\n        self.down = down\n        self.output_scale_factor = output_scale_factor\n        self.time_embedding_norm = time_embedding_norm\n        self.skip_time_act = skip_time_act\n\n        if groups_out is None:\n            groups_out = groups\n\n        self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)\n\n        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)\n\n        if temb_channels is not None:\n            if self.time_embedding_norm == 
\"default\":\n                self.time_emb_proj = nn.Linear(temb_channels, out_channels)\n            elif self.time_embedding_norm == \"scale_shift\":\n                self.time_emb_proj = nn.Linear(temb_channels, 2 * out_channels)\n            else:\n                raise ValueError(f\"unknown time_embedding_norm : {self.time_embedding_norm} \")\n        else:\n            self.time_emb_proj = None\n\n        self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)\n\n        self.dropout = torch.nn.Dropout(dropout)\n        conv_2d_out_channels = conv_2d_out_channels or out_channels\n        self.conv2 = nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)\n\n        self.nonlinearity = get_activation(non_linearity)\n\n        self.upsample = self.downsample = None\n        if self.up:\n            if kernel == \"fir\":\n                fir_kernel = (1, 3, 3, 1)\n                self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel)\n            elif kernel == \"sde_vp\":\n                self.upsample = partial(F.interpolate, scale_factor=2.0, mode=\"nearest\")\n            else:\n                self.upsample = Upsample2D(in_channels, use_conv=False)\n        elif self.down:\n            if kernel == \"fir\":\n                fir_kernel = (1, 3, 3, 1)\n                self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel)\n            elif kernel == \"sde_vp\":\n                self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)\n            else:\n                self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name=\"op\")\n\n        self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut\n\n        self.conv_shortcut = None\n        if self.use_in_shortcut:\n            self.conv_shortcut = nn.Conv2d(\n                in_channels,\n                conv_2d_out_channels,\n         
       kernel_size=1,\n                stride=1,\n                padding=0,\n                bias=conv_shortcut_bias,\n            )\n\n    def forward(self, input_tensor: torch.Tensor, temb: torch.Tensor, *args, **kwargs) -> torch.Tensor:\n        if len(args) > 0 or kwargs.get(\"scale\", None) is not None:\n            deprecation_message = \"The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`.\"\n            deprecate(\"scale\", \"1.0.0\", deprecation_message)\n\n        hidden_states = input_tensor\n\n        hidden_states = self.norm1(hidden_states)\n        hidden_states = self.nonlinearity(hidden_states)\n\n        if self.upsample is not None:\n            # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984\n            if hidden_states.shape[0] >= 64:\n                input_tensor = input_tensor.contiguous()\n                hidden_states = hidden_states.contiguous()\n            input_tensor = self.upsample(input_tensor)\n            hidden_states = self.upsample(hidden_states)\n        elif self.downsample is not None:\n            input_tensor = self.downsample(input_tensor)\n            hidden_states = self.downsample(hidden_states)\n\n        hidden_states = self.conv1(hidden_states)\n\n        if self.time_emb_proj is not None:\n            if not self.skip_time_act:\n                temb = self.nonlinearity(temb)\n            temb = self.time_emb_proj(temb)[:, :, None, None]\n\n        if self.time_embedding_norm == \"default\":\n            if temb is not None:\n                hidden_states = hidden_states + temb\n            hidden_states = self.norm2(hidden_states)\n        elif self.time_embedding_norm == \"scale_shift\":\n            if temb is None:\n                raise ValueError(\n                
    f\" `temb` should not be None when `time_embedding_norm` is {self.time_embedding_norm}\"\n                )\n            time_scale, time_shift = torch.chunk(temb, 2, dim=1)\n            hidden_states = self.norm2(hidden_states)\n            hidden_states = hidden_states * (1 + time_scale) + time_shift\n        else:\n            hidden_states = self.norm2(hidden_states)\n\n        hidden_states = self.nonlinearity(hidden_states)\n\n        hidden_states = self.dropout(hidden_states)\n        hidden_states = self.conv2(hidden_states)\n\n        if self.conv_shortcut is not None:\n            input_tensor = self.conv_shortcut(input_tensor.contiguous())\n\n        output_tensor = (input_tensor + hidden_states) / self.output_scale_factor\n\n        return output_tensor\n\nclass Downsample2D(nn.Module):\n    \"\"\"A 2D downsampling layer with an optional convolution.\n\n    Parameters:\n        channels (`int`):\n            number of channels in the inputs and outputs.\n        use_conv (`bool`, default `False`):\n            option to use a convolution.\n        out_channels (`int`, optional):\n            number of output channels. 
Defaults to `channels`.\n        padding (`int`, default `1`):\n            padding for the convolution.\n        name (`str`, default `conv`):\n            name of the downsampling 2D layer.\n    \"\"\"\n\n    def __init__(\n        self,\n        channels: int,\n        use_conv: bool = False,\n        out_channels: Optional[int] = None,\n        padding: int = 1,\n        name: str = \"conv\",\n        kernel_size=3,\n        norm_type=None,\n        eps=None,\n        elementwise_affine=None,\n        bias=True,\n    ):\n        super().__init__()\n        self.channels = channels\n        self.out_channels = out_channels or channels\n        self.use_conv = use_conv\n        self.padding = padding\n        stride = 2\n        self.name = name\n\n        if norm_type == \"ln_norm\":\n            self.norm = nn.LayerNorm(channels, eps, elementwise_affine)\n        elif norm_type == \"rms_norm\":\n            self.norm = RMSNorm(channels, eps, elementwise_affine)\n        elif norm_type is None:\n            self.norm = None\n        else:\n            raise ValueError(f\"unknown norm_type: {norm_type}\")\n\n        if use_conv:\n            conv = nn.Conv2d(\n                self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias\n            )\n        else:\n            assert self.channels == self.out_channels\n            conv = nn.AvgPool2d(kernel_size=stride, stride=stride)\n\n        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed\n        if name == \"conv\":\n            self.Conv2d_0 = conv\n            self.conv = conv\n        elif name == \"Conv2d_0\":\n            self.conv = conv\n        else:\n            self.conv = conv\n\n    def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:\n        if len(args) > 0 or kwargs.get(\"scale\", None) is not None:\n            deprecation_message = \"The `scale` argument is deprecated and will be ignored. 
Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`.\"\n            deprecate(\"scale\", \"1.0.0\", deprecation_message)\n        assert hidden_states.shape[1] == self.channels\n\n        if self.norm is not None:\n            hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)\n\n        if self.use_conv and self.padding == 0:\n            pad = (0, 1, 0, 1)\n            hidden_states = F.pad(hidden_states, pad, mode=\"constant\", value=0)\n\n        assert hidden_states.shape[1] == self.channels\n\n        hidden_states = self.conv(hidden_states)\n\n        return hidden_states\n\nclass Upsample2D(nn.Module):\n    \"\"\"A 2D upsampling layer with an optional convolution.\n\n    Parameters:\n        channels (`int`):\n            number of channels in the inputs and outputs.\n        use_conv (`bool`, default `False`):\n            option to use a convolution.\n        use_conv_transpose (`bool`, default `False`):\n            option to use a convolution transpose.\n        out_channels (`int`, optional):\n            number of output channels. 
Defaults to `channels`.\n        name (`str`, default `conv`):\n            name of the upsampling 2D layer.\n    \"\"\"\n\n    def __init__(\n        self,\n        channels: int,\n        use_conv: bool = False,\n        use_conv_transpose: bool = False,\n        out_channels: Optional[int] = None,\n        name: str = \"conv\",\n        kernel_size: Optional[int] = None,\n        padding=1,\n        norm_type=None,\n        eps=None,\n        elementwise_affine=None,\n        bias=True,\n        interpolate=True,\n    ):\n        super().__init__()\n        self.channels = channels\n        self.out_channels = out_channels or channels\n        self.use_conv = use_conv\n        self.use_conv_transpose = use_conv_transpose\n        self.name = name\n        self.interpolate = interpolate\n\n        if norm_type == \"ln_norm\":\n            self.norm = nn.LayerNorm(channels, eps, elementwise_affine)\n        elif norm_type == \"rms_norm\":\n            self.norm = RMSNorm(channels, eps, elementwise_affine)\n        elif norm_type is None:\n            self.norm = None\n        else:\n            raise ValueError(f\"unknown norm_type: {norm_type}\")\n\n        conv = None\n        if use_conv_transpose:\n            if kernel_size is None:\n                kernel_size = 4\n            conv = nn.ConvTranspose2d(\n                channels, self.out_channels, kernel_size=kernel_size, stride=2, padding=padding, bias=bias\n            )\n        elif use_conv:\n            if kernel_size is None:\n                kernel_size = 3\n            conv = nn.Conv2d(self.channels, self.out_channels, kernel_size=kernel_size, padding=padding, bias=bias)\n\n        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed\n        if name == \"conv\":\n            self.conv = conv\n        else:\n            self.Conv2d_0 = conv\n\n    def forward(self, hidden_states: torch.Tensor, output_size: Optional[int] = None, *args, **kwargs) -> torch.Tensor:\n        if 
len(args) > 0 or kwargs.get(\"scale\", None) is not None:\n            deprecation_message = \"The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`.\"\n            deprecate(\"scale\", \"1.0.0\", deprecation_message)\n\n        assert hidden_states.shape[1] == self.channels\n\n        if self.norm is not None:\n            hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)\n\n        if self.use_conv_transpose:\n            return self.conv(hidden_states)\n\n        # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 until PyTorch 2.1\n        # https://github.com/pytorch/pytorch/issues/86679#issuecomment-1783978767\n        dtype = hidden_states.dtype\n        if dtype == torch.bfloat16:\n            hidden_states = hidden_states.to(torch.float32)\n\n        # upsample_nearest_nhwc fails with large batch sizes. 
see https://github.com/huggingface/diffusers/issues/984\n        if hidden_states.shape[0] >= 64:\n            hidden_states = hidden_states.contiguous()\n\n        # if `output_size` is passed we force the interpolation output\n        # size and do not make use of `scale_factor=2`\n        if self.interpolate:\n            # upsample_nearest_nhwc also fails when the number of output elements is large\n            # https://github.com/pytorch/pytorch/issues/141831\n            scale_factor = (\n                2 if output_size is None else max([f / s for f, s in zip(output_size, hidden_states.shape[-2:])])\n            )\n            if hidden_states.numel() * scale_factor > pow(2, 31):\n                hidden_states = hidden_states.contiguous()\n\n            if output_size is None:\n                hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode=\"nearest\")\n            else:\n                hidden_states = F.interpolate(hidden_states, size=output_size, mode=\"nearest\")\n\n        # Cast back to original dtype\n        if dtype == torch.bfloat16:\n            hidden_states = hidden_states.to(dtype)\n\n        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed\n        if self.use_conv:\n            if self.name == \"conv\":\n                hidden_states = self.conv(hidden_states)\n            else:\n                hidden_states = self.Conv2d_0(hidden_states)\n\n        return hidden_states\n\n\nclass Attention(nn.Module):\n    r\"\"\"\n    A cross attention layer.\n\n    Parameters:\n        query_dim (`int`):\n            The number of channels in the query.\n        cross_attention_dim (`int`, *optional*):\n            The number of channels in the encoder_hidden_states. 
If not given, defaults to `query_dim`.\n        heads (`int`,  *optional*, defaults to 8):\n            The number of heads to use for multi-head attention.\n        kv_heads (`int`,  *optional*, defaults to `None`):\n            The number of key and value heads to use for multi-head attention. Defaults to `heads`. If\n            `kv_heads=heads`, the model will use Multi Head Attention (MHA), if `kv_heads=1` the model will use Multi\n            Query Attention (MQA) otherwise GQA is used.\n        dim_head (`int`,  *optional*, defaults to 64):\n            The number of channels in each head.\n        dropout (`float`, *optional*, defaults to 0.0):\n            The dropout probability to use.\n        bias (`bool`, *optional*, defaults to False):\n            Set to `True` for the query, key, and value linear layers to contain a bias parameter.\n        upcast_attention (`bool`, *optional*, defaults to False):\n            Set to `True` to upcast the attention computation to `float32`.\n        upcast_softmax (`bool`, *optional*, defaults to False):\n            Set to `True` to upcast the softmax computation to `float32`.\n        cross_attention_norm (`str`, *optional*, defaults to `None`):\n            The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`.\n        cross_attention_norm_num_groups (`int`, *optional*, defaults to 32):\n            The number of groups to use for the group norm in the cross attention.\n        added_kv_proj_dim (`int`, *optional*, defaults to `None`):\n            The number of channels to use for the added key and value projections. 
If `None`, no projection is used.\n        norm_num_groups (`int`, *optional*, defaults to `None`):\n            The number of groups to use for the group norm in the attention.\n        spatial_norm_dim (`int`, *optional*, defaults to `None`):\n            The number of channels to use for the spatial normalization.\n        out_bias (`bool`, *optional*, defaults to `True`):\n            Set to `True` to use a bias in the output linear layer.\n        scale_qk (`bool`, *optional*, defaults to `True`):\n            Set to `True` to scale the query and key by `1 / sqrt(dim_head)`.\n        only_cross_attention (`bool`, *optional*, defaults to `False`):\n            Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if\n            `added_kv_proj_dim` is not `None`.\n        eps (`float`, *optional*, defaults to 1e-5):\n            An additional value added to the denominator in group normalization that is used for numerical stability.\n        rescale_output_factor (`float`, *optional*, defaults to 1.0):\n            A factor to rescale the output by dividing it with this value.\n        residual_connection (`bool`, *optional*, defaults to `False`):\n            Set to `True` to add the residual connection to the output.\n        _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`):\n            Set to `True` if the attention block is loaded from a deprecated state dict.\n        processor (`AttnProcessor`, *optional*, defaults to `None`):\n            The attention processor to use. 
If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and\n            `AttnProcessor` otherwise.\n    \"\"\"\n\n    def __init__(\n        self,\n        query_dim: int,\n        cross_attention_dim: Optional[int] = None,\n        heads: int = 8,\n        kv_heads: Optional[int] = None,\n        dim_head: int = 64,\n        dropout: float = 0.0,\n        bias: bool = False,\n        upcast_attention: bool = False,\n        upcast_softmax: bool = False,\n        cross_attention_norm: Optional[str] = None,\n        cross_attention_norm_num_groups: int = 32,\n        qk_norm: Optional[str] = None,\n        added_kv_proj_dim: Optional[int] = None,\n        added_proj_bias: Optional[bool] = True,\n        norm_num_groups: Optional[int] = None,\n        spatial_norm_dim: Optional[int] = None,\n        out_bias: bool = True,\n        scale_qk: bool = True,\n        only_cross_attention: bool = False,\n        eps: float = 1e-5,\n        rescale_output_factor: float = 1.0,\n        residual_connection: bool = False,\n        _from_deprecated_attn_block: bool = False,\n        processor: Optional[\"AttnProcessor\"] = None,\n        out_dim: int = None,\n        out_context_dim: int = None,\n        context_pre_only=None,\n        pre_only=False,\n        elementwise_affine: bool = True,\n        is_causal: bool = False,\n    ):\n        super().__init__()\n\n        # To prevent circular import.\n        # from .normalization import FP32LayerNorm, LpNorm, RMSNorm\n\n        self.inner_dim = out_dim if out_dim is not None else dim_head * heads\n        self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads\n        self.query_dim = query_dim\n        self.use_bias = bias\n        self.is_cross_attention = cross_attention_dim is not None\n        self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim\n        self.upcast_attention = upcast_attention\n        self.upcast_softmax = upcast_softmax\n   
     self.rescale_output_factor = rescale_output_factor\n        self.residual_connection = residual_connection\n        self.dropout = dropout\n        self.fused_projections = False\n        self.out_dim = out_dim if out_dim is not None else query_dim\n        self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim\n        self.context_pre_only = context_pre_only\n        self.pre_only = pre_only\n        self.is_causal = is_causal\n\n        # we make use of this private variable to know whether this class is loaded\n        # with an deprecated state dict so that we can convert it on the fly\n        self._from_deprecated_attn_block = _from_deprecated_attn_block\n\n        self.scale_qk = scale_qk\n        self.scale = dim_head**-0.5 if self.scale_qk else 1.0\n\n        self.heads = out_dim // dim_head if out_dim is not None else heads\n        # for slice_size > 0 the attention score computation\n        # is split across the batch axis to save memory\n        # You can set slice_size with `set_attention_slice`\n        self.sliceable_head_dim = heads\n\n        self.added_kv_proj_dim = added_kv_proj_dim\n        self.only_cross_attention = only_cross_attention\n\n        if self.added_kv_proj_dim is None and self.only_cross_attention:\n            raise ValueError(\n                \"`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. 
Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`.\"\n            )\n\n        if norm_num_groups is not None:\n            self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)\n        else:\n            self.group_norm = None\n\n        if spatial_norm_dim is not None:\n            self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim)\n        else:\n            self.spatial_norm = None\n\n        if qk_norm is None:\n            self.norm_q = None\n            self.norm_k = None\n        elif qk_norm == \"layer_norm\":\n            self.norm_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)\n            self.norm_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)\n        elif qk_norm == \"fp32_layer_norm\":\n            self.norm_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)\n            self.norm_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)\n        elif qk_norm == \"layer_norm_across_heads\":\n            # Lumina applies qk norm across all heads\n            self.norm_q = nn.LayerNorm(dim_head * heads, eps=eps)\n            self.norm_k = nn.LayerNorm(dim_head * kv_heads, eps=eps)\n        elif qk_norm == \"rms_norm\":\n            self.norm_q = RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)\n            self.norm_k = RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)\n        elif qk_norm == \"rms_norm_across_heads\":\n            # LTX applies qk norm across all heads\n            self.norm_q = RMSNorm(dim_head * heads, eps=eps)\n            self.norm_k = RMSNorm(dim_head * kv_heads, eps=eps)\n        elif qk_norm == \"l2\":\n            self.norm_q = LpNorm(p=2, dim=-1, eps=eps)\n            self.norm_k = LpNorm(p=2, dim=-1, eps=eps)\n        else:\n            raise ValueError(\n                
f\"unknown qk_norm: {qk_norm}. Should be one of None, 'layer_norm', 'fp32_layer_norm', 'layer_norm_across_heads', 'rms_norm', 'rms_norm_across_heads', 'l2'.\"\n            )\n\n        if cross_attention_norm is None:\n            self.norm_cross = None\n        elif cross_attention_norm == \"layer_norm\":\n            self.norm_cross = nn.LayerNorm(self.cross_attention_dim)\n        elif cross_attention_norm == \"group_norm\":\n            if self.added_kv_proj_dim is not None:\n                # The given `encoder_hidden_states` are initially of shape\n                # (batch_size, seq_len, added_kv_proj_dim) before being projected\n                # to (batch_size, seq_len, cross_attention_dim). The norm is applied\n                # before the projection, so we need to use `added_kv_proj_dim` as\n                # the number of channels for the group norm.\n                norm_cross_num_channels = added_kv_proj_dim\n            else:\n                norm_cross_num_channels = self.cross_attention_dim\n\n            self.norm_cross = nn.GroupNorm(\n                num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True\n            )\n        else:\n            raise ValueError(\n                f\"unknown cross_attention_norm: {cross_attention_norm}. 
Should be None, 'layer_norm' or 'group_norm'\"\n            )\n\n        self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)\n\n        if not self.only_cross_attention:\n            # only relevant for the `AddedKVProcessor` classes\n            self.to_k = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)\n            self.to_v = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)\n        else:\n            self.to_k = None\n            self.to_v = None\n\n        self.added_proj_bias = added_proj_bias\n        if self.added_kv_proj_dim is not None:\n            self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)\n            self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)\n            if self.context_pre_only is not None:\n                self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)\n        else:\n            self.add_q_proj = None\n            self.add_k_proj = None\n            self.add_v_proj = None\n\n        if not self.pre_only:\n            self.to_out = nn.ModuleList([])\n            self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))\n            self.to_out.append(nn.Dropout(dropout))\n        else:\n            self.to_out = None\n\n        if self.context_pre_only is not None and not self.context_pre_only:\n            self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias)\n        else:\n            self.to_add_out = None\n\n        if qk_norm is not None and added_kv_proj_dim is not None:\n            if qk_norm == \"layer_norm\":\n                self.norm_added_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)\n                self.norm_added_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)\n            elif qk_norm == \"fp32_layer_norm\":\n                self.norm_added_q = 
FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)\n                self.norm_added_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)\n            elif qk_norm == \"rms_norm\":\n                self.norm_added_q = RMSNorm(dim_head, eps=eps)\n                self.norm_added_k = RMSNorm(dim_head, eps=eps)\n            elif qk_norm == \"rms_norm_across_heads\":\n                # Wan applies qk norm across all heads\n                # Wan also doesn't apply a q norm\n                self.norm_added_q = None\n                self.norm_added_k = RMSNorm(dim_head * kv_heads, eps=eps)\n            else:\n                raise ValueError(\n                    f\"unknown qk_norm: {qk_norm}. Should be one of `None,'layer_norm','fp32_layer_norm','rms_norm'`\"\n                )\n        else:\n            self.norm_added_q = None\n            self.norm_added_k = None\n\n        # set attention processor\n        # We use the AttnProcessor2_0 by default when torch 2.x is used which uses\n        # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention\n        # but only if it has the default `scale` argument. 
TODO remove scale_qk check when we move to torch 2.1\n        if processor is None:\n            processor = (\n                AttnProcessor2_0() if hasattr(F, \"scaled_dot_product_attention\") and self.scale_qk else AttnProcessor()\n            )\n        self.set_processor(processor)\n\n    def set_use_xla_flash_attention(\n        self,\n        use_xla_flash_attention: bool,\n        partition_spec: Optional[Tuple[Optional[str], ...]] = None,\n        is_flux=False,\n    ) -> None:\n        r\"\"\"\n        Set whether to use xla flash attention from `torch_xla` or not.\n\n        Args:\n            use_xla_flash_attention (`bool`):\n                Whether to use pallas flash attention kernel from `torch_xla` or not.\n            partition_spec (`Tuple[]`, *optional*):\n                Specify the partition specification if using SPMD. Otherwise None.\n        \"\"\"\n        if use_xla_flash_attention:\n            if not is_torch_xla_available:\n                raise \"torch_xla is not available\"\n            elif is_torch_xla_version(\"<\", \"2.3\"):\n                raise \"flash attention pallas kernel is supported from torch_xla version 2.3\"\n            elif is_spmd() and is_torch_xla_version(\"<\", \"2.4\"):\n                raise \"flash attention pallas kernel using SPMD is supported from torch_xla version 2.4\"\n            else:\n                if is_flux:\n                    processor = XLAFluxFlashAttnProcessor2_0(partition_spec)\n                else:\n                    processor = XLAFlashAttnProcessor2_0(partition_spec)\n        else:\n            processor = (\n                AttnProcessor2_0() if hasattr(F, \"scaled_dot_product_attention\") and self.scale_qk else AttnProcessor()\n            )\n        self.set_processor(processor)\n\n    def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None:\n        r\"\"\"\n        Set whether to use npu flash attention from `torch_npu` or not.\n\n        \"\"\"\n        if 
use_npu_flash_attention:\n            processor = AttnProcessorNPU()\n        else:\n            # set attention processor\n            # We use the AttnProcessor2_0 by default when torch 2.x is used which uses\n            # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention\n            # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1\n            processor = (\n                AttnProcessor2_0() if hasattr(F, \"scaled_dot_product_attention\") and self.scale_qk else AttnProcessor()\n            )\n        self.set_processor(processor)\n\n    def set_use_memory_efficient_attention_xformers(\n        self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None\n    ) -> None:\n        r\"\"\"\n        Set whether to use memory efficient attention from `xformers` or not.\n\n        Args:\n            use_memory_efficient_attention_xformers (`bool`):\n                Whether to use memory efficient attention from `xformers` or not.\n            attention_op (`Callable`, *optional*):\n                The attention operation to use. 
Defaults to `None` which uses the default attention operation from\n                `xformers`.\n        \"\"\"\n        is_custom_diffusion = hasattr(self, \"processor\") and isinstance(\n            self.processor,\n            (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessor2_0),\n        )\n        is_added_kv_processor = hasattr(self, \"processor\") and isinstance(\n            self.processor,\n            (\n                AttnAddedKVProcessor,\n                AttnAddedKVProcessor2_0,\n                SlicedAttnAddedKVProcessor,\n                XFormersAttnAddedKVProcessor,\n            ),\n        )\n        is_ip_adapter = hasattr(self, \"processor\") and isinstance(\n            self.processor,\n            (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor),\n        )\n        is_joint_processor = hasattr(self, \"processor\") and isinstance(\n            self.processor,\n            (\n                JointAttnProcessor2_0,\n                XFormersJointAttnProcessor,\n            ),\n        )\n\n        if use_memory_efficient_attention_xformers:\n            if is_added_kv_processor and is_custom_diffusion:\n                raise NotImplementedError(\n                    f\"Memory efficient attention is currently not supported for custom diffusion for attention processor type {self.processor}\"\n                )\n            if not is_xformers_available():\n                raise ModuleNotFoundError(\n                    (\n                        \"Refer to https://github.com/facebookresearch/xformers for more information on how to install\"\n                        \" xformers\"\n                    ),\n                    name=\"xformers\",\n                )\n            elif not torch.cuda.is_available():\n                raise ValueError(\n                    \"torch.cuda.is_available() should be True but is False. 
xformers' memory efficient attention is\"\n                    \" only available for GPU \"\n                )\n            else:\n                try:\n                    # Make sure we can run the memory efficient attention\n                    dtype = None\n                    if attention_op is not None:\n                        op_fw, op_bw = attention_op\n                        dtype, *_ = op_fw.SUPPORTED_DTYPES\n                    q = torch.randn((1, 2, 40), device=\"cuda\", dtype=dtype)\n                    _ = xformers.ops.memory_efficient_attention(q, q, q)\n                except Exception as e:\n                    raise e\n\n            if is_custom_diffusion:\n                processor = CustomDiffusionXFormersAttnProcessor(\n                    train_kv=self.processor.train_kv,\n                    train_q_out=self.processor.train_q_out,\n                    hidden_size=self.processor.hidden_size,\n                    cross_attention_dim=self.processor.cross_attention_dim,\n                    attention_op=attention_op,\n                )\n                processor.load_state_dict(self.processor.state_dict())\n                if hasattr(self.processor, \"to_k_custom_diffusion\"):\n                    processor.to(self.processor.to_k_custom_diffusion.weight.device)\n            elif is_added_kv_processor:\n                # TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP\n                # which uses this type of cross attention ONLY because the attention mask of format\n                # [0, ..., -10.000, ..., 0, ...,] is not supported\n                # throw warning\n                logger.info(\n                    \"Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation.\"\n                )\n                processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)\n            elif is_ip_adapter:\n                processor 
= IPAdapterXFormersAttnProcessor(\n                    hidden_size=self.processor.hidden_size,\n                    cross_attention_dim=self.processor.cross_attention_dim,\n                    num_tokens=self.processor.num_tokens,\n                    scale=self.processor.scale,\n                    attention_op=attention_op,\n                )\n                processor.load_state_dict(self.processor.state_dict())\n                if hasattr(self.processor, \"to_k_ip\"):\n                    processor.to(\n                        device=self.processor.to_k_ip[0].weight.device, dtype=self.processor.to_k_ip[0].weight.dtype\n                    )\n            elif is_joint_processor:\n                processor = XFormersJointAttnProcessor(attention_op=attention_op)\n            else:\n                processor = XFormersAttnProcessor(attention_op=attention_op)\n        else:\n            if is_custom_diffusion:\n                attn_processor_class = (\n                    CustomDiffusionAttnProcessor2_0\n                    if hasattr(F, \"scaled_dot_product_attention\")\n                    else CustomDiffusionAttnProcessor\n                )\n                processor = attn_processor_class(\n                    train_kv=self.processor.train_kv,\n                    train_q_out=self.processor.train_q_out,\n                    hidden_size=self.processor.hidden_size,\n                    cross_attention_dim=self.processor.cross_attention_dim,\n                )\n                processor.load_state_dict(self.processor.state_dict())\n                if hasattr(self.processor, \"to_k_custom_diffusion\"):\n                    processor.to(self.processor.to_k_custom_diffusion.weight.device)\n            elif is_ip_adapter:\n                processor = IPAdapterAttnProcessor2_0(\n                    hidden_size=self.processor.hidden_size,\n                    cross_attention_dim=self.processor.cross_attention_dim,\n                    
num_tokens=self.processor.num_tokens,\n                    scale=self.processor.scale,\n                )\n                processor.load_state_dict(self.processor.state_dict())\n                if hasattr(self.processor, \"to_k_ip\"):\n                    processor.to(\n                        device=self.processor.to_k_ip[0].weight.device, dtype=self.processor.to_k_ip[0].weight.dtype\n                    )\n            else:\n                # set attention processor\n                # We use the AttnProcessor2_0 by default when torch 2.x is used which uses\n                # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention\n                # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1\n                processor = (\n                    AttnProcessor2_0()\n                    if hasattr(F, \"scaled_dot_product_attention\") and self.scale_qk\n                    else AttnProcessor()\n                )\n\n        self.set_processor(processor)\n\n    def set_attention_slice(self, slice_size: int) -> None:\n        r\"\"\"\n        Set the slice size for attention computation.\n\n        Args:\n            slice_size (`int`):\n                The slice size for attention computation.\n        \"\"\"\n        if slice_size is not None and slice_size > self.sliceable_head_dim:\n            raise ValueError(f\"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.\")\n\n        if slice_size is not None and self.added_kv_proj_dim is not None:\n            processor = SlicedAttnAddedKVProcessor(slice_size)\n        elif slice_size is not None:\n            processor = SlicedAttnProcessor(slice_size)\n        elif self.added_kv_proj_dim is not None:\n            processor = AttnAddedKVProcessor()\n        else:\n            # set attention processor\n            # We use the AttnProcessor2_0 by default when torch 2.x is used which uses\n          
  # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention\n            # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1\n            processor = (\n                AttnProcessor2_0() if hasattr(F, \"scaled_dot_product_attention\") and self.scale_qk else AttnProcessor()\n            )\n\n        self.set_processor(processor)\n\n    def set_processor(self, processor: \"AttnProcessor\") -> None:\n        r\"\"\"\n        Set the attention processor to use.\n\n        Args:\n            processor (`AttnProcessor`):\n                The attention processor to use.\n        \"\"\"\n        # if current processor is in `self._modules` and if passed `processor` is not, we need to\n        # pop `processor` from `self._modules`\n        if (\n            hasattr(self, \"processor\")\n            and isinstance(self.processor, torch.nn.Module)\n            and not isinstance(processor, torch.nn.Module)\n        ):\n            logger.info(f\"You are removing possibly trained weights of {self.processor} with {processor}\")\n            self._modules.pop(\"processor\")\n\n        self.processor = processor\n\n    def get_processor(self, return_deprecated_lora: bool = False) -> \"AttentionProcessor\":\n        r\"\"\"\n        Get the attention processor in use.\n\n        Args:\n            return_deprecated_lora (`bool`, *optional*, defaults to `False`):\n                Set to `True` to return the deprecated LoRA attention processor.\n\n        Returns:\n            \"AttentionProcessor\": The attention processor in use.\n        \"\"\"\n        if not return_deprecated_lora:\n            return self.processor\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        encoder_hidden_states: Optional[torch.Tensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        **cross_attention_kwargs,\n    ) -> torch.Tensor:\n        r\"\"\"\n        
The forward method of the `Attention` class.\n\n        Args:\n            hidden_states (`torch.Tensor`):\n                The hidden states of the query.\n            encoder_hidden_states (`torch.Tensor`, *optional*):\n                The hidden states of the encoder.\n            attention_mask (`torch.Tensor`, *optional*):\n                The attention mask to use. If `None`, no mask is applied.\n            **cross_attention_kwargs:\n                Additional keyword arguments to pass along to the cross attention.\n\n        Returns:\n            `torch.Tensor`: The output of the attention layer.\n        \"\"\"\n        # The `Attention` class can call different attention processors / attention functions\n        # here we simply pass along all tensors to the selected processor class\n        # For standard processors that are defined here, `**cross_attention_kwargs` is empty\n\n        attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())\n        quiet_attn_parameters = {\"ip_adapter_masks\", \"ip_hidden_states\"}\n        unused_kwargs = [\n            k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters\n        ]\n        if len(unused_kwargs) > 0:\n            logger.warning(\n                f\"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored.\"\n            )\n        cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters}\n\n        return self.processor(\n            self,\n            hidden_states,\n            encoder_hidden_states=encoder_hidden_states,\n            attention_mask=attention_mask,\n            **cross_attention_kwargs,\n        )\n\n    def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:\n        r\"\"\"\n        Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. 
`heads`\n        is the number of heads initialized while constructing the `Attention` class.\n\n        Args:\n            tensor (`torch.Tensor`): The tensor to reshape.\n\n        Returns:\n            `torch.Tensor`: The reshaped tensor.\n        \"\"\"\n        head_size = self.heads\n        batch_size, seq_len, dim = tensor.shape\n        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)\n        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)\n        return tensor\n\n    def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:\n        r\"\"\"\n        Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]` `heads` is\n        the number of heads initialized while constructing the `Attention` class.\n\n        Args:\n            tensor (`torch.Tensor`): The tensor to reshape.\n            out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. 
If `3`, the tensor is\n                reshaped to `[batch_size * heads, seq_len, dim // heads]`.\n\n        Returns:\n            `torch.Tensor`: The reshaped tensor.\n        \"\"\"\n        head_size = self.heads\n        if tensor.ndim == 3:\n            batch_size, seq_len, dim = tensor.shape\n            extra_dim = 1\n        else:\n            batch_size, extra_dim, seq_len, dim = tensor.shape\n        tensor = tensor.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size)\n        tensor = tensor.permute(0, 2, 1, 3)\n\n        if out_dim == 3:\n            tensor = tensor.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size)\n\n        return tensor\n\n    def get_attention_scores(\n        self, query: torch.Tensor, key: torch.Tensor, attention_mask: Optional[torch.Tensor] = None\n    ) -> torch.Tensor:\n        r\"\"\"\n        Compute the attention scores.\n\n        Args:\n            query (`torch.Tensor`): The query tensor.\n            key (`torch.Tensor`): The key tensor.\n            attention_mask (`torch.Tensor`, *optional*): The attention mask to use. 
If `None`, no mask is applied.\n\n        Returns:\n            `torch.Tensor`: The attention probabilities/scores.\n        \"\"\"\n        dtype = query.dtype\n        if self.upcast_attention:\n            query = query.float()\n            key = key.float()\n\n        if attention_mask is None:\n            baddbmm_input = torch.empty(\n                query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device\n            )\n            beta = 0\n        else:\n            baddbmm_input = attention_mask\n            beta = 1\n\n        attention_scores = torch.baddbmm(\n            baddbmm_input,\n            query,\n            key.transpose(-1, -2),\n            beta=beta,\n            alpha=self.scale,\n        )\n        del baddbmm_input\n\n        if self.upcast_softmax:\n            attention_scores = attention_scores.float()\n\n        attention_probs = attention_scores.softmax(dim=-1)\n        del attention_scores\n\n        attention_probs = attention_probs.to(dtype)\n\n        return attention_probs\n\n    def prepare_attention_mask(\n        self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3\n    ) -> torch.Tensor:\n        r\"\"\"\n        Prepare the attention mask for the attention computation.\n\n        Args:\n            attention_mask (`torch.Tensor`):\n                The attention mask to prepare.\n            target_length (`int`):\n                The target length of the attention mask. This is the length of the attention mask after padding.\n            batch_size (`int`):\n                The batch size, which is used to repeat the attention mask.\n            out_dim (`int`, *optional*, defaults to `3`):\n                The output dimension of the attention mask. 
Can be either `3` or `4`.\n\n        Returns:\n            `torch.Tensor`: The prepared attention mask.\n        \"\"\"\n        head_size = self.heads\n        if attention_mask is None:\n            return attention_mask\n\n        current_length: int = attention_mask.shape[-1]\n        if current_length != target_length:\n            if attention_mask.device.type == \"mps\":\n                # HACK: MPS: Does not support padding by greater than dimension of input tensor.\n                # Instead, we can manually construct the padding tensor.\n                padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)\n                padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)\n                attention_mask = torch.cat([attention_mask, padding], dim=2)\n            else:\n                # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:\n                #       we want to instead pad by (0, remaining_length), where remaining_length is:\n                #       remaining_length: int = target_length - current_length\n                # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding\n                attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)\n\n        if out_dim == 3:\n            if attention_mask.shape[0] < batch_size * head_size:\n                attention_mask = attention_mask.repeat_interleave(\n                    head_size, dim=0, output_size=attention_mask.shape[0] * head_size\n                )\n        elif out_dim == 4:\n            attention_mask = attention_mask.unsqueeze(1)\n            attention_mask = attention_mask.repeat_interleave(\n                head_size, dim=1, output_size=attention_mask.shape[1] * head_size\n            )\n\n        return attention_mask\n\n    def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:\n        r\"\"\"\n        
Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the\n        `Attention` class.\n\n        Args:\n            encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder.\n\n        Returns:\n            `torch.Tensor`: The normalized encoder hidden states.\n        \"\"\"\n        assert self.norm_cross is not None, \"self.norm_cross must be defined to call self.norm_encoder_hidden_states\"\n\n        if isinstance(self.norm_cross, nn.LayerNorm):\n            encoder_hidden_states = self.norm_cross(encoder_hidden_states)\n        elif isinstance(self.norm_cross, nn.GroupNorm):\n            # Group norm norms along the channels dimension and expects\n            # input to be in the shape of (N, C, *). In this case, we want\n            # to norm along the hidden dimension, so we need to move\n            # (batch_size, sequence_length, hidden_size) ->\n            # (batch_size, hidden_size, sequence_length)\n            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)\n            encoder_hidden_states = self.norm_cross(encoder_hidden_states)\n            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)\n        else:\n            assert False\n\n        return encoder_hidden_states\n\n    @torch.no_grad()\n    def fuse_projections(self, fuse=True):\n        device = self.to_q.weight.data.device\n        dtype = self.to_q.weight.data.dtype\n\n        if not self.is_cross_attention:\n            # fetch weight matrices.\n            concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])\n            in_features = concatenated_weights.shape[1]\n            out_features = concatenated_weights.shape[0]\n\n            # create a new single projection layer and copy over the weights.\n            self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)\n            
self.to_qkv.weight.copy_(concatenated_weights)\n            if self.use_bias:\n                concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])\n                self.to_qkv.bias.copy_(concatenated_bias)\n\n        else:\n            concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])\n            in_features = concatenated_weights.shape[1]\n            out_features = concatenated_weights.shape[0]\n\n            self.to_kv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)\n            self.to_kv.weight.copy_(concatenated_weights)\n            if self.use_bias:\n                concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])\n                self.to_kv.bias.copy_(concatenated_bias)\n\n        # handle added projections for SD3 and others.\n        if (\n            getattr(self, \"add_q_proj\", None) is not None\n            and getattr(self, \"add_k_proj\", None) is not None\n            and getattr(self, \"add_v_proj\", None) is not None\n        ):\n            concatenated_weights = torch.cat(\n                [self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data]\n            )\n            in_features = concatenated_weights.shape[1]\n            out_features = concatenated_weights.shape[0]\n\n            self.to_added_qkv = nn.Linear(\n                in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype\n            )\n            self.to_added_qkv.weight.copy_(concatenated_weights)\n            if self.added_proj_bias:\n                concatenated_bias = torch.cat(\n                    [self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data]\n                )\n                self.to_added_qkv.bias.copy_(concatenated_bias)\n\n        self.fused_projections = fuse\n\nclass AttnProcessor2_0:\n    r\"\"\"\n    Processor for implementing 
scaled dot-product attention (enabled by default if you're using PyTorch 2.0).\n    \"\"\"\n\n    def __init__(self):\n        if not hasattr(F, \"scaled_dot_product_attention\"):\n            raise ImportError(\"AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.\")\n\n    def __call__(\n        self,\n        attn: Attention,\n        hidden_states: torch.Tensor,\n        encoder_hidden_states: Optional[torch.Tensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        temb: Optional[torch.Tensor] = None,\n        *args,\n        **kwargs,\n    ) -> torch.Tensor:\n        if len(args) > 0 or kwargs.get(\"scale\", None) is not None:\n            deprecation_message = \"The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`.\"\n            deprecate(\"scale\", \"1.0.0\", deprecation_message)\n\n        residual = hidden_states\n        if attn.spatial_norm is not None:\n            hidden_states = attn.spatial_norm(hidden_states, temb)\n\n        input_ndim = hidden_states.ndim\n\n        if input_ndim == 4:\n            batch_size, channel, height, width = hidden_states.shape\n            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)\n\n        batch_size, sequence_length, _ = (\n            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape\n        )\n\n        if attention_mask is not None:\n            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)\n            # scaled_dot_product_attention expects attention_mask shape to be\n            # (batch, heads, source_length, target_length)\n            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])\n\n        if 
attn.group_norm is not None:\n            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)\n\n        query = attn.to_q(hidden_states)\n\n        if encoder_hidden_states is None:\n            encoder_hidden_states = hidden_states\n        elif attn.norm_cross:\n            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)\n\n        key = attn.to_k(encoder_hidden_states)\n        value = attn.to_v(encoder_hidden_states)\n\n        inner_dim = key.shape[-1]\n        head_dim = inner_dim // attn.heads\n\n        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n\n        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n\n        if attn.norm_q is not None:\n            query = attn.norm_q(query)\n        if attn.norm_k is not None:\n            key = attn.norm_k(key)\n\n        # the output of sdp = (batch, num_heads, seq_len, head_dim)\n        # TODO: add support for attn.scale when we move to Torch 2.1\n        hidden_states = F.scaled_dot_product_attention(\n            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False\n        )\n\n        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)\n        hidden_states = hidden_states.to(query.dtype)\n\n        # linear proj\n        hidden_states = attn.to_out[0](hidden_states)\n        # dropout\n        hidden_states = attn.to_out[1](hidden_states)\n\n        if input_ndim == 4:\n            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)\n\n        if attn.residual_connection:\n            hidden_states = hidden_states + residual\n\n        hidden_states = hidden_states / attn.rescale_output_factor\n\n        return hidden_states\n\nclass UNetMidBlock2D(nn.Module):\n    \"\"\"\n    A 2D UNet mid-block [`UNetMidBlock2D`] 
with multiple residual blocks and optional attention blocks.\n\n    Args:\n        in_channels (`int`): The number of input channels.\n        temb_channels (`int`): The number of temporal embedding channels.\n        dropout (`float`, *optional*, defaults to 0.0): The dropout rate.\n        num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.\n        resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks.\n        resnet_time_scale_shift (`str`, *optional*, defaults to `default`):\n            The type of normalization to apply to the time embeddings. This can help to improve the performance of the\n            model on tasks with long-range temporal dependencies.\n        resnet_act_fn (`str`, *optional*, defaults to `swish`): The activation function for the resnet blocks.\n        resnet_groups (`int`, *optional*, defaults to 32):\n            The number of groups to use in the group normalization layers of the resnet blocks.\n        attn_groups (`Optional[int]`, *optional*, defaults to None): The number of groups for the attention blocks.\n        resnet_pre_norm (`bool`, *optional*, defaults to `True`):\n            Whether to use pre-normalization for the resnet blocks.\n        add_attention (`bool`, *optional*, defaults to `True`): Whether to add attention blocks.\n        attention_head_dim (`int`, *optional*, defaults to 1):\n            Dimension of a single attention head. 
The number of attention heads is determined based on this value and\n            the number of input channels.\n        output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor.\n\n    Returns:\n        `torch.Tensor`: The output of the last residual block, which is a tensor of shape `(batch_size, in_channels,\n        height, width)`.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        in_channels: int,\n        temb_channels: int,\n        dropout: float = 0.0,\n        num_layers: int = 1,\n        resnet_eps: float = 1e-6,\n        resnet_time_scale_shift: str = \"default\",  # default, spatial\n        resnet_act_fn: str = \"swish\",\n        resnet_groups: int = 32,\n        attn_groups: Optional[int] = None,\n        resnet_pre_norm: bool = True,\n        add_attention: bool = True,\n        attention_head_dim: int = 1,\n        output_scale_factor: float = 1.0,\n    ):\n        super().__init__()\n        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)\n        self.add_attention = add_attention\n\n        if attn_groups is None:\n            attn_groups = resnet_groups if resnet_time_scale_shift == \"default\" else None\n\n        # there is always at least one resnet\n        if resnet_time_scale_shift == \"spatial\":\n            resnets = [\n                ResnetBlockCondNorm2D(\n                    in_channels=in_channels,\n                    out_channels=in_channels,\n                    temb_channels=temb_channels,\n                    eps=resnet_eps,\n                    groups=resnet_groups,\n                    dropout=dropout,\n                    time_embedding_norm=\"spatial\",\n                    non_linearity=resnet_act_fn,\n                    output_scale_factor=output_scale_factor,\n                )\n            ]\n        else:\n            resnets = [\n                ResnetBlock2D(\n                    in_channels=in_channels,\n                    
out_channels=in_channels,\n                    temb_channels=temb_channels,\n                    eps=resnet_eps,\n                    groups=resnet_groups,\n                    dropout=dropout,\n                    time_embedding_norm=resnet_time_scale_shift,\n                    non_linearity=resnet_act_fn,\n                    output_scale_factor=output_scale_factor,\n                    pre_norm=resnet_pre_norm,\n                )\n            ]\n        attentions = []\n\n        if attention_head_dim is None:\n            logger.warning(\n                f\"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}.\"\n            )\n            attention_head_dim = in_channels\n\n        for _ in range(num_layers):\n            if self.add_attention:\n                attentions.append(\n                    Attention(\n                        in_channels,\n                        heads=in_channels // attention_head_dim,\n                        dim_head=attention_head_dim,\n                        rescale_output_factor=output_scale_factor,\n                        eps=resnet_eps,\n                        norm_num_groups=attn_groups,\n                        spatial_norm_dim=temb_channels if resnet_time_scale_shift == \"spatial\" else None,\n                        residual_connection=True,\n                        bias=True,\n                        upcast_softmax=True,\n                        _from_deprecated_attn_block=True,\n                    )\n                )\n            else:\n                attentions.append(None)\n\n            if resnet_time_scale_shift == \"spatial\":\n                resnets.append(\n                    ResnetBlockCondNorm2D(\n                        in_channels=in_channels,\n                        out_channels=in_channels,\n                        temb_channels=temb_channels,\n                        eps=resnet_eps,\n                        groups=resnet_groups,\n  
                      dropout=dropout,\n                        time_embedding_norm=\"spatial\",\n                        non_linearity=resnet_act_fn,\n                        output_scale_factor=output_scale_factor,\n                    )\n                )\n            else:\n                resnets.append(\n                    ResnetBlock2D(\n                        in_channels=in_channels,\n                        out_channels=in_channels,\n                        temb_channels=temb_channels,\n                        eps=resnet_eps,\n                        groups=resnet_groups,\n                        dropout=dropout,\n                        time_embedding_norm=resnet_time_scale_shift,\n                        non_linearity=resnet_act_fn,\n                        output_scale_factor=output_scale_factor,\n                        pre_norm=resnet_pre_norm,\n                    )\n                )\n\n        self.attentions = nn.ModuleList(attentions)\n        self.resnets = nn.ModuleList(resnets)\n\n        self.gradient_checkpointing = False\n\n    def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:\n        hidden_states = self.resnets[0](hidden_states, temb)\n        for attn, resnet in zip(self.attentions, self.resnets[1:]):\n            if torch.is_grad_enabled() and self.gradient_checkpointing:\n                if attn is not None:\n                    hidden_states = attn(hidden_states, temb=temb)\n                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)\n            else:\n                if attn is not None:\n                    hidden_states = attn(hidden_states, temb=temb)\n                hidden_states = resnet(hidden_states, temb)\n\n        return hidden_states\n\nclass DownEncoderBlock2D(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        dropout: float = 0.0,\n        num_layers: int = 1,\n        
resnet_eps: float = 1e-6,\n        resnet_time_scale_shift: str = \"default\",\n        resnet_act_fn: str = \"swish\",\n        resnet_groups: int = 32,\n        resnet_pre_norm: bool = True,\n        output_scale_factor: float = 1.0,\n        add_downsample: bool = True,\n        downsample_padding: int = 1,\n    ):\n        super().__init__()\n        resnets = []\n\n        for i in range(num_layers):\n            in_channels = in_channels if i == 0 else out_channels\n            if resnet_time_scale_shift == \"spatial\":\n                resnets.append(\n                    ResnetBlockCondNorm2D(\n                        in_channels=in_channels,\n                        out_channels=out_channels,\n                        temb_channels=None,\n                        eps=resnet_eps,\n                        groups=resnet_groups,\n                        dropout=dropout,\n                        time_embedding_norm=\"spatial\",\n                        non_linearity=resnet_act_fn,\n                        output_scale_factor=output_scale_factor,\n                    )\n                )\n            else:\n                resnets.append(\n                    ResnetBlock2D(\n                        in_channels=in_channels,\n                        out_channels=out_channels,\n                        temb_channels=None,\n                        eps=resnet_eps,\n                        groups=resnet_groups,\n                        dropout=dropout,\n                        time_embedding_norm=resnet_time_scale_shift,\n                        non_linearity=resnet_act_fn,\n                        output_scale_factor=output_scale_factor,\n                        pre_norm=resnet_pre_norm,\n                    )\n                )\n\n        self.resnets = nn.ModuleList(resnets)\n\n        if add_downsample:\n            self.downsamplers = nn.ModuleList(\n                [\n                    Downsample2D(\n                        out_channels, use_conv=True, 
out_channels=out_channels, padding=downsample_padding, name=\"op\"\n                    )\n                ]\n            )\n        else:\n            self.downsamplers = None\n\n    def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:\n        if len(args) > 0 or kwargs.get(\"scale\", None) is not None:\n            deprecation_message = \"The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`.\"\n            deprecate(\"scale\", \"1.0.0\", deprecation_message)\n\n        for resnet in self.resnets:\n            hidden_states = resnet(hidden_states, temb=None)\n\n        if self.downsamplers is not None:\n            for downsampler in self.downsamplers:\n                hidden_states = downsampler(hidden_states)\n\n        return hidden_states\n\n\nclass UpDecoderBlock2D(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        resolution_idx: Optional[int] = None,\n        dropout: float = 0.0,\n        num_layers: int = 1,\n        resnet_eps: float = 1e-6,\n        resnet_time_scale_shift: str = \"default\",  # default, spatial\n        resnet_act_fn: str = \"swish\",\n        resnet_groups: int = 32,\n        resnet_pre_norm: bool = True,\n        output_scale_factor: float = 1.0,\n        add_upsample: bool = True,\n        temb_channels: Optional[int] = None,\n    ):\n        super().__init__()\n        resnets = []\n\n        for i in range(num_layers):\n            input_channels = in_channels if i == 0 else out_channels\n\n            if resnet_time_scale_shift == \"spatial\":\n                resnets.append(\n                    ResnetBlockCondNorm2D(\n                        in_channels=input_channels,\n                        out_channels=out_channels,\n                        
temb_channels=temb_channels,\n                        eps=resnet_eps,\n                        groups=resnet_groups,\n                        dropout=dropout,\n                        time_embedding_norm=\"spatial\",\n                        non_linearity=resnet_act_fn,\n                        output_scale_factor=output_scale_factor,\n                    )\n                )\n            else:\n                resnets.append(\n                    ResnetBlock2D(\n                        in_channels=input_channels,\n                        out_channels=out_channels,\n                        temb_channels=temb_channels,\n                        eps=resnet_eps,\n                        groups=resnet_groups,\n                        dropout=dropout,\n                        time_embedding_norm=resnet_time_scale_shift,\n                        non_linearity=resnet_act_fn,\n                        output_scale_factor=output_scale_factor,\n                        pre_norm=resnet_pre_norm,\n                    )\n                )\n\n        self.resnets = nn.ModuleList(resnets)\n\n        if add_upsample:\n            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])\n        else:\n            self.upsamplers = None\n\n        self.resolution_idx = resolution_idx\n\n    def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:\n        for resnet in self.resnets:\n            hidden_states = resnet(hidden_states, temb=temb)\n\n        if self.upsamplers is not None:\n            for upsampler in self.upsamplers:\n                hidden_states = upsampler(hidden_states)\n\n        return hidden_states\n\nclass Encoder(nn.Module):\n    r\"\"\"\n    The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.\n\n    Args:\n        in_channels (`int`, *optional*, defaults to 3):\n            The number of input channels.\n        
out_channels (`int`, *optional*, defaults to 3):\n            The number of output channels.\n        down_block_types (`Tuple[str, ...]`, *optional*, defaults to `(\"DownEncoderBlock2D\",)`):\n            The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available\n            options.\n        block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):\n            The number of output channels for each block.\n        layers_per_block (`int`, *optional*, defaults to 2):\n            The number of layers per block.\n        norm_num_groups (`int`, *optional*, defaults to 32):\n            The number of groups for normalization.\n        act_fn (`str`, *optional*, defaults to `\"silu\"`):\n            The activation function to use. See `~diffusers.models.activations.get_activation` for available options.\n        double_z (`bool`, *optional*, defaults to `True`):\n            Whether to double the number of output channels for the last block.\n    \"\"\"\n\n    def __init__(\n        self,\n        in_channels: int = 3,\n        out_channels: int = 3,\n        down_block_types: Tuple[str, ...] = (\"DownEncoderBlock2D\",),\n        block_out_channels: Tuple[int, ...] 
= (64,),\n        layers_per_block: int = 2,\n        norm_num_groups: int = 32,\n        act_fn: str = \"silu\",\n        double_z: bool = True,\n        mid_block_add_attention=True,\n    ):\n        super().__init__()\n        self.layers_per_block = layers_per_block\n\n        self.conv_in = nn.Conv2d(\n            in_channels,\n            block_out_channels[0],\n            kernel_size=3,\n            stride=1,\n            padding=1,\n        )\n\n        self.down_blocks = nn.ModuleList([])\n\n        # down\n        output_channel = block_out_channels[0]\n        for i, down_block_type in enumerate(down_block_types):\n            input_channel = output_channel\n            output_channel = block_out_channels[i]\n            is_final_block = i == len(block_out_channels) - 1\n\n            down_block = DownEncoderBlock2D(\n                num_layers=self.layers_per_block,\n                in_channels=input_channel,\n                out_channels=output_channel,\n                add_downsample=not is_final_block,\n                resnet_eps=1e-6,\n                downsample_padding=0,\n                resnet_act_fn=act_fn,\n                resnet_groups=norm_num_groups,\n                # attention_head_dim=output_channel,\n                # temb_channels=None,\n            )\n            self.down_blocks.append(down_block)\n\n        # mid\n        self.mid_block = UNetMidBlock2D(\n            in_channels=block_out_channels[-1],\n            resnet_eps=1e-6,\n            resnet_act_fn=act_fn,\n            output_scale_factor=1,\n            resnet_time_scale_shift=\"default\",\n            attention_head_dim=block_out_channels[-1],\n            resnet_groups=norm_num_groups,\n            temb_channels=None,\n            add_attention=mid_block_add_attention,\n        )\n\n        # out\n        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)\n        self.conv_act = nn.SiLU()\n\n        
conv_out_channels = 2 * out_channels if double_z else out_channels\n        self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)\n\n        self.gradient_checkpointing = False\n\n    def forward(self, sample: torch.Tensor) -> torch.Tensor:\n        r\"\"\"The forward method of the `Encoder` class.\"\"\"\n\n        sample = self.conv_in(sample)\n\n        if torch.is_grad_enabled() and self.gradient_checkpointing:\n            # down\n            for down_block in self.down_blocks:\n                sample = self._gradient_checkpointing_func(down_block, sample)\n            # middle\n            sample = self._gradient_checkpointing_func(self.mid_block, sample)\n\n        else:\n            # down\n            for down_block in self.down_blocks:\n                sample = down_block(sample)\n\n            # middle\n            sample = self.mid_block(sample)\n\n        # post-process\n        sample = self.conv_norm_out(sample)\n        sample = self.conv_act(sample)\n        sample = self.conv_out(sample)\n\n        return sample\n\nclass Decoder(nn.Module):\n    r\"\"\"\n    The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample.\n\n    Args:\n        in_channels (`int`, *optional*, defaults to 3):\n            The number of input channels.\n        out_channels (`int`, *optional*, defaults to 3):\n            The number of output channels.\n        up_block_types (`Tuple[str, ...]`, *optional*, defaults to `(\"UpDecoderBlock2D\",)`):\n            The types of up blocks to use. 
See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.\n        block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):\n            The number of output channels for each block.\n        layers_per_block (`int`, *optional*, defaults to 2):\n            The number of layers per block.\n        norm_num_groups (`int`, *optional*, defaults to 32):\n            The number of groups for normalization.\n        act_fn (`str`, *optional*, defaults to `\"silu\"`):\n            The activation function to use. See `~diffusers.models.activations.get_activation` for available options.\n        norm_type (`str`, *optional*, defaults to `\"group\"`):\n            The normalization type to use. Can be either `\"group\"` or `\"spatial\"`.\n    \"\"\"\n\n    def __init__(\n        self,\n        in_channels: int = 3,\n        out_channels: int = 3,\n        up_block_types: Tuple[str, ...] = (\"UpDecoderBlock2D\",),\n        block_out_channels: Tuple[int, ...] = (64,),\n        layers_per_block: int = 2,\n        norm_num_groups: int = 32,\n        act_fn: str = \"silu\",\n        norm_type: str = \"group\",  # group, spatial\n        mid_block_add_attention=True,\n    ):\n        super().__init__()\n        self.layers_per_block = layers_per_block\n\n        self.conv_in = nn.Conv2d(\n            in_channels,\n            block_out_channels[-1],\n            kernel_size=3,\n            stride=1,\n            padding=1,\n        )\n\n        self.up_blocks = nn.ModuleList([])\n\n        temb_channels = in_channels if norm_type == \"spatial\" else None\n\n        # mid\n        self.mid_block = UNetMidBlock2D(\n            in_channels=block_out_channels[-1],\n            resnet_eps=1e-6,\n            resnet_act_fn=act_fn,\n            output_scale_factor=1,\n            resnet_time_scale_shift=\"default\" if norm_type == \"group\" else norm_type,\n            attention_head_dim=block_out_channels[-1],\n            resnet_groups=norm_num_groups,\n 
           temb_channels=temb_channels,\n            add_attention=mid_block_add_attention,\n        )\n\n        # up\n        reversed_block_out_channels = list(reversed(block_out_channels))\n        output_channel = reversed_block_out_channels[0]\n        for i, up_block_type in enumerate(up_block_types):\n            prev_output_channel = output_channel\n            output_channel = reversed_block_out_channels[i]\n\n            is_final_block = i == len(block_out_channels) - 1\n\n            up_block = UpDecoderBlock2D(\n                num_layers=self.layers_per_block + 1,\n                in_channels=prev_output_channel,\n                out_channels=output_channel,\n                # prev_output_channel=prev_output_channel,\n                add_upsample=not is_final_block,\n                resnet_eps=1e-6,\n                resnet_act_fn=act_fn,\n                resnet_groups=norm_num_groups,\n                # attention_head_dim=output_channel,\n                temb_channels=temb_channels,\n                resnet_time_scale_shift=norm_type,\n            )\n            self.up_blocks.append(up_block)\n            prev_output_channel = output_channel\n\n        # out\n        if norm_type == \"spatial\":\n            self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)\n        else:\n            self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)\n        self.conv_act = nn.SiLU()\n        self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)\n\n        self.gradient_checkpointing = False\n\n    def forward(\n        self,\n        sample: torch.Tensor,\n        latent_embeds: Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n        r\"\"\"The forward method of the `Decoder` class.\"\"\"\n\n        sample = self.conv_in(sample)\n\n        if torch.is_grad_enabled() and self.gradient_checkpointing:\n            # middle\n            sample = 
self._gradient_checkpointing_func(self.mid_block, sample, latent_embeds)\n\n            # up\n            for up_block in self.up_blocks:\n                sample = self._gradient_checkpointing_func(up_block, sample, latent_embeds)\n        else:\n            # middle\n            sample = self.mid_block(sample, latent_embeds)\n\n            # up\n            for up_block in self.up_blocks:\n                sample = up_block(sample, latent_embeds)\n\n        # post-process\n        if latent_embeds is None:\n            sample = self.conv_norm_out(sample)\n        else:\n            sample = self.conv_norm_out(sample, latent_embeds)\n        sample = self.conv_act(sample)\n        sample = self.conv_out(sample)\n\n        return sample\n\n\nclass Flux2VAE(torch.nn.Module):\n    r\"\"\"\n    A VAE model with KL loss for encoding images into latents and decoding latent representations into images.\n\n    This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented\n    for all models (such as downloading or saving).\n\n    Parameters:\n        in_channels (int, *optional*, defaults to 3): Number of channels in the input image.\n        out_channels (int,  *optional*, defaults to 3): Number of channels in the output.\n        down_block_types (`Tuple[str]`, *optional*, defaults to `(\"DownEncoderBlock2D\",)`):\n            Tuple of downsample block types.\n        up_block_types (`Tuple[str]`, *optional*, defaults to `(\"UpDecoderBlock2D\",)`):\n            Tuple of upsample block types.\n        block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):\n            Tuple of block output channels.\n        act_fn (`str`, *optional*, defaults to `\"silu\"`): The activation function to use.\n        latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.\n        sample_size (`int`, *optional*, defaults to `32`): Sample input size.\n        force_upcast (`bool`, *optional*, 
default to `True`):\n            If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE\n            can be fine-tuned / trained to a lower range without losing too much precision in which case `force_upcast`\n            can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix\n        mid_block_add_attention (`bool`, *optional*, default to `True`):\n            If enabled, the mid_block of the Encoder and Decoder will have attention blocks. If set to false, the\n            mid_block will only have resnet blocks\n    \"\"\"\n\n    _supports_gradient_checkpointing = True\n    _no_split_modules = [\"BasicTransformerBlock\", \"ResnetBlock2D\"]\n\n    def __init__(\n        self,\n        in_channels: int = 3,\n        out_channels: int = 3,\n        down_block_types: Tuple[str, ...] = (\n            \"DownEncoderBlock2D\",\n            \"DownEncoderBlock2D\",\n            \"DownEncoderBlock2D\",\n            \"DownEncoderBlock2D\",\n        ),\n        up_block_types: Tuple[str, ...] = (\n            \"UpDecoderBlock2D\",\n            \"UpDecoderBlock2D\",\n            \"UpDecoderBlock2D\",\n            \"UpDecoderBlock2D\",\n        ),\n        block_out_channels: Tuple[int, ...] 
= (\n            128,\n            256,\n            512,\n            512,\n        ),\n        layers_per_block: int = 2,\n        act_fn: str = \"silu\",\n        latent_channels: int = 32,\n        norm_num_groups: int = 32,\n        sample_size: int = 1024,  # YiYi notes: not sure\n        force_upcast: bool = True,\n        use_quant_conv: bool = True,\n        use_post_quant_conv: bool = True,\n        mid_block_add_attention: bool = True,\n        batch_norm_eps: float = 1e-4,\n        batch_norm_momentum: float = 0.1,\n        patch_size: Tuple[int, int] = (2, 2),\n    ):\n        super().__init__()\n\n        # pass init params to Encoder\n        self.encoder = Encoder(\n            in_channels=in_channels,\n            out_channels=latent_channels,\n            down_block_types=down_block_types,\n            block_out_channels=block_out_channels,\n            layers_per_block=layers_per_block,\n            act_fn=act_fn,\n            norm_num_groups=norm_num_groups,\n            double_z=True,\n            mid_block_add_attention=mid_block_add_attention,\n        )\n\n        # pass init params to Decoder\n        self.decoder = Decoder(\n            in_channels=latent_channels,\n            out_channels=out_channels,\n            up_block_types=up_block_types,\n            block_out_channels=block_out_channels,\n            layers_per_block=layers_per_block,\n            norm_num_groups=norm_num_groups,\n            act_fn=act_fn,\n            mid_block_add_attention=mid_block_add_attention,\n        )\n\n        self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None\n        self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) if use_post_quant_conv else None\n\n        self.bn = nn.BatchNorm2d(\n            math.prod(patch_size) * latent_channels,\n            eps=batch_norm_eps,\n            momentum=batch_norm_momentum,\n            affine=False,\n            track_running_stats=True,\n 
       )\n\n        self.use_slicing = False\n        self.use_tiling = False\n\n    @property\n    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors\n    def attn_processors(self):\n        r\"\"\"\n        Returns:\n            `dict` of attention processors: A dictionary containing all attention processors used in the model with\n            indexed by its weight name.\n        \"\"\"\n        # set recursively\n        processors = {}\n\n        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors):\n            if hasattr(module, \"get_processor\"):\n                processors[f\"{name}.processor\"] = module.get_processor()\n\n            for sub_name, child in module.named_children():\n                fn_recursive_add_processors(f\"{name}.{sub_name}\", child, processors)\n\n            return processors\n\n        for name, module in self.named_children():\n            fn_recursive_add_processors(name, module, processors)\n\n        return processors\n\n    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor\n    def set_attn_processor(self, processor):\n        r\"\"\"\n        Sets the attention processor to use to compute attention.\n\n        Parameters:\n            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):\n                The instantiated processor class or a dictionary of processor classes that will be set as the processor\n                for **all** `Attention` layers.\n\n                If `processor` is a dict, the key needs to define the path to the corresponding cross attention\n                processor. 
This is strongly recommended when setting trainable attention processors.\n\n        \"\"\"\n        count = len(self.attn_processors.keys())\n\n        if isinstance(processor, dict) and len(processor) != count:\n            raise ValueError(\n                f\"A dict of processors was passed, but the number of processors {len(processor)} does not match the\"\n                f\" number of attention layers: {count}. Please make sure to pass {count} processor classes.\"\n            )\n\n        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):\n            if hasattr(module, \"set_processor\"):\n                if not isinstance(processor, dict):\n                    module.set_processor(processor)\n                else:\n                    module.set_processor(processor.pop(f\"{name}.processor\"))\n\n            for sub_name, child in module.named_children():\n                fn_recursive_attn_processor(f\"{name}.{sub_name}\", child, processor)\n\n        for name, module in self.named_children():\n            fn_recursive_attn_processor(name, module, processor)\n\n    def _encode(self, x: torch.Tensor) -> torch.Tensor:\n        batch_size, num_channels, height, width = x.shape\n\n        if self.use_tiling and (width > self.tile_sample_min_size or height > self.tile_sample_min_size):\n            return self._tiled_encode(x)\n\n        enc = self.encoder(x)\n        if self.quant_conv is not None:\n            enc = self.quant_conv(enc)\n\n        return enc\n\n    def encode(\n        self, x: torch.Tensor, return_dict: bool = True\n    ):\n        \"\"\"\n        Encode a batch of images into latents.\n\n        Args:\n            x (`torch.Tensor`): Input batch of images.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.\n\n        Returns:\n                The latent representations of the encoded 
images. If `return_dict` is True, a\n                [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.\n        \"\"\"\n        if self.use_slicing and x.shape[0] > 1:\n            encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]\n            h = torch.cat(encoded_slices)\n        else:\n            h = self._encode(x)\n\n\n        h = rearrange(h, \"B C (H P) (W Q) -> B (C P Q) H W\", P=2, Q=2)\n        h = h[:, :128]\n        latents_bn_mean = self.bn.running_mean.view(1, -1, 1, 1).to(h.device, h.dtype)\n        latents_bn_std = torch.sqrt(self.bn.running_var.view(1, -1, 1, 1) + 0.0001).to(\n            h.device, h.dtype\n        )\n        h = (h - latents_bn_mean) / latents_bn_std\n        return h\n\n    def _decode(self, z: torch.Tensor, return_dict: bool = True):\n        if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):\n            return self.tiled_decode(z, return_dict=return_dict)\n\n        if self.post_quant_conv is not None:\n            z = self.post_quant_conv(z)\n\n        dec = self.decoder(z)\n\n        if not return_dict:\n            return (dec,)\n\n        return dec\n\n    def decode(\n        self, z: torch.FloatTensor, return_dict: bool = True, generator=None\n    ):\n        latents_bn_mean = self.bn.running_mean.view(1, -1, 1, 1).to(z.device, z.dtype)\n        latents_bn_std = torch.sqrt(self.bn.running_var.view(1, -1, 1, 1) + 0.0001).to(\n            z.device, z.dtype\n        )\n        z = z * latents_bn_std + latents_bn_mean\n        z = rearrange(z, \"B (C P Q) H W -> B C (H P) (W Q)\", P=2, Q=2)\n        \"\"\"\n        Decode a batch of images.\n\n        Args:\n            z (`torch.Tensor`): Input batch of latent vectors.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.\n\n        Returns:\n     
       [`~models.vae.DecoderOutput`] or `tuple`:\n                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is\n                returned.\n\n        \"\"\"\n        if self.use_slicing and z.shape[0] > 1:\n            decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)]\n            decoded = torch.cat(decoded_slices)\n        else:\n            decoded = self._decode(z)\n\n        if not return_dict:\n            return (decoded,)\n\n        return decoded\n\n    def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:\n        blend_extent = min(a.shape[2], b.shape[2], blend_extent)\n        for y in range(blend_extent):\n            b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)\n        return b\n\n    def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:\n        blend_extent = min(a.shape[3], b.shape[3], blend_extent)\n        for x in range(blend_extent):\n            b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)\n        return b\n\n    def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor:\n        r\"\"\"Encode a batch of images using a tiled encoder.\n\n        When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several\n        steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is\n        different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the\n        tiles overlap and are blended together to form a smooth output. 
You may still see tile-sized changes in the\n        output, but they should be much less noticeable.\n\n        Args:\n            x (`torch.Tensor`): Input batch of images.\n\n        Returns:\n            `torch.Tensor`:\n                The latent representation of the encoded videos.\n        \"\"\"\n\n        overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))\n        blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)\n        row_limit = self.tile_latent_min_size - blend_extent\n\n        # Split the image into 512x512 tiles and encode them separately.\n        rows = []\n        for i in range(0, x.shape[2], overlap_size):\n            row = []\n            for j in range(0, x.shape[3], overlap_size):\n                tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]\n                tile = self.encoder(tile)\n                if self.config.use_quant_conv:\n                    tile = self.quant_conv(tile)\n                row.append(tile)\n            rows.append(row)\n        result_rows = []\n        for i, row in enumerate(rows):\n            result_row = []\n            for j, tile in enumerate(row):\n                # blend the above tile and the left tile\n                # to the current tile and add the current tile to the result row\n                if i > 0:\n                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)\n                if j > 0:\n                    tile = self.blend_h(row[j - 1], tile, blend_extent)\n                result_row.append(tile[:, :, :row_limit, :row_limit])\n            result_rows.append(torch.cat(result_row, dim=3))\n\n        enc = torch.cat(result_rows, dim=2)\n        return enc\n\n    def tiled_encode(self, x: torch.Tensor, return_dict: bool = True):\n        r\"\"\"Encode a batch of images using a tiled encoder.\n\n        When this option is enabled, the VAE will split the input tensor into tiles to compute 
encoding in several\n        steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is\n        different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the\n        tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the\n        output, but they should be much less noticeable.\n\n        Args:\n            x (`torch.Tensor`): Input batch of images.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.\n\n        Returns:\n            [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:\n                If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain\n                `tuple` is returned.\n        \"\"\"\n\n        overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))\n        blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)\n        row_limit = self.tile_latent_min_size - blend_extent\n\n        # Split the image into 512x512 tiles and encode them separately.\n        rows = []\n        for i in range(0, x.shape[2], overlap_size):\n            row = []\n            for j in range(0, x.shape[3], overlap_size):\n                tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]\n                tile = self.encoder(tile)\n                if self.config.use_quant_conv:\n                    tile = self.quant_conv(tile)\n                row.append(tile)\n            rows.append(row)\n        result_rows = []\n        for i, row in enumerate(rows):\n            result_row = []\n            for j, tile in enumerate(row):\n                # blend the above tile and the left tile\n                # to the current tile and add the current tile to the 
result row\n                if i > 0:\n                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)\n                if j > 0:\n                    tile = self.blend_h(row[j - 1], tile, blend_extent)\n                result_row.append(tile[:, :, :row_limit, :row_limit])\n            result_rows.append(torch.cat(result_row, dim=3))\n\n        moments = torch.cat(result_rows, dim=2)\n        return moments\n\n    def tiled_decode(self, z: torch.Tensor, return_dict: bool = True):\n        r\"\"\"\n        Decode a batch of images using a tiled decoder.\n\n        Args:\n            z (`torch.Tensor`): Input batch of latent vectors.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.\n\n        Returns:\n            [`~models.vae.DecoderOutput`] or `tuple`:\n                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is\n                returned.\n        \"\"\"\n        overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))\n        blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)\n        row_limit = self.tile_sample_min_size - blend_extent\n\n        # Split z into overlapping 64x64 tiles and decode them separately.\n        # The tiles have an overlap to avoid seams between tiles.\n        rows = []\n        for i in range(0, z.shape[2], overlap_size):\n            row = []\n            for j in range(0, z.shape[3], overlap_size):\n                tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]\n                if self.config.use_post_quant_conv:\n                    tile = self.post_quant_conv(tile)\n                decoded = self.decoder(tile)\n                row.append(decoded)\n            rows.append(row)\n        result_rows = []\n        for i, row in enumerate(rows):\n            result_row = 
[]\n            for j, tile in enumerate(row):\n                # blend the above tile and the left tile\n                # to the current tile and add the current tile to the result row\n                if i > 0:\n                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)\n                if j > 0:\n                    tile = self.blend_h(row[j - 1], tile, blend_extent)\n                result_row.append(tile[:, :, :row_limit, :row_limit])\n            result_rows.append(torch.cat(result_row, dim=3))\n\n        dec = torch.cat(result_rows, dim=2)\n        if not return_dict:\n            return (dec,)\n\n        return dec\n\n    def forward(\n        self,\n        sample: torch.Tensor,\n        sample_posterior: bool = False,\n        return_dict: bool = True,\n        generator: Optional[torch.Generator] = None,\n    ):\n        r\"\"\"\n        Args:\n            sample (`torch.Tensor`): Input sample.\n            sample_posterior (`bool`, *optional*, defaults to `False`):\n                Whether to sample from the posterior.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`DecoderOutput`] instead of a plain tuple.\n        \"\"\"\n        x = sample\n        posterior = self.encode(x).latent_dist\n        if sample_posterior:\n            z = posterior.sample(generator=generator)\n        else:\n            z = posterior.mode()\n        dec = self.decode(z).sample\n\n        if not return_dict:\n            return (dec,)\n\n        return dec\n"
  },
  {
    "path": "diffsynth/models/flux_controlnet.py",
    "content": "import torch\nimport hashlib\nfrom einops import rearrange, repeat\nfrom .flux_dit import RoPEEmbedding, TimestepEmbeddings, FluxJointTransformerBlock, FluxSingleTransformerBlock, RMSNorm\n# from .utils import hash_state_dict_keys, init_weights_on_device\nfrom contextlib import contextmanager\n\n\ndef convert_state_dict_keys_to_single_str(state_dict, with_shape=True):\n    \"\"\"Serialize the key layout of a (possibly nested) state dict into one sorted string.\n\n    Defined locally because the `.utils` import above is commented out. The exact format matters:\n    `FluxControlNetStateDictConverter.from_diffusers` compares MD5 hashes of this string against\n    hard-coded values, so any change here changes every hash.\n    \"\"\"\n    keys = []\n    for key, value in state_dict.items():\n        if isinstance(key, str):\n            if isinstance(value, torch.Tensor):\n                if with_shape:\n                    shape = \"_\".join(map(str, list(value.shape)))\n                    keys.append(key + \":\" + shape)\n                keys.append(key)\n            elif isinstance(value, dict):\n                keys.append(key + \"|\" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape))\n    keys.sort()\n    keys_str = \",\".join(keys)\n    return keys_str\n\n\ndef hash_state_dict_keys(state_dict, with_shape=True):\n    \"\"\"Return the MD5 hex digest of the serialized key layout of `state_dict`.\"\"\"\n    keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)\n    keys_str = keys_str.encode(encoding=\"UTF-8\")\n    return hashlib.md5(keys_str).hexdigest()\n\n\n@contextmanager\ndef init_weights_on_device(device = torch.device(\"meta\"), include_buffers :bool = False):\n    \"\"\"Temporarily create newly registered module parameters (and optionally buffers) on `device`.\n\n    While active, `torch.nn.Module.register_parameter` is patched to move every newly registered\n    parameter to `device` (the meta device by default). With `include_buffers=True`,\n    `register_buffer` and the `torch.empty/zeros/ones/full` constructors are patched as well.\n    All patched functions are restored on exit, even if the body raises.\n    \"\"\"\n    old_register_parameter = torch.nn.Module.register_parameter\n    if include_buffers:\n        old_register_buffer = torch.nn.Module.register_buffer\n\n    def register_empty_parameter(module, name, param):\n        old_register_parameter(module, name, param)\n        if param is not None:\n            param_cls = type(module._parameters[name])\n            kwargs = module._parameters[name].__dict__\n            kwargs[\"requires_grad\"] = param.requires_grad\n            module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)\n\n    def register_empty_buffer(module, name, buffer, persistent=True):\n        old_register_buffer(module, name, buffer, persistent=persistent)\n        if buffer is not None:\n            module._buffers[name] = module._buffers[name].to(device)\n\n    def patch_tensor_constructor(fn):\n        def wrapper(*args, **kwargs):\n            kwargs[\"device\"] = device\n            return fn(*args, **kwargs)\n\n        return wrapper\n\n    if include_buffers:\n        tensor_constructors_to_patch = {\n            torch_function_name: getattr(torch, torch_function_name)\n            for torch_function_name in [\"empty\", \"zeros\", \"ones\", \"full\"]\n        }\n    else:\n        tensor_constructors_to_patch = {}\n\n    try:\n        torch.nn.Module.register_parameter = register_empty_parameter\n        if include_buffers:\n            torch.nn.Module.register_buffer = register_empty_buffer\n        for torch_function_name in tensor_constructors_to_patch.keys():\n            setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))\n        yield\n    finally:\n        torch.nn.Module.register_parameter = old_register_parameter\n        if include_buffers:\n            torch.nn.Module.register_buffer = old_register_buffer\n        for torch_function_name, old_torch_function in tensor_constructors_to_patch.items():\n            setattr(torch, torch_function_name, old_torch_function)\n\n\nclass FluxControlNet(torch.nn.Module):\n    \"\"\"ControlNet branch built from Flux joint/single transformer blocks.\n\n    `forward` returns two lists of per-block residuals, expanded to 19 joint-block and 38\n    single-block entries by `align_res_stack_to_original_blocks`.\n    \"\"\"\n    def __init__(self, disable_guidance_embedder=False, num_joint_blocks=5, num_single_blocks=10, num_mode=0, mode_dict=None, additional_input_dim=0):\n        super().__init__()\n        self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])\n        self.time_embedder = TimestepEmbeddings(256, 3072)\n        self.guidance_embedder = None if disable_guidance_embedder else TimestepEmbeddings(256, 3072)\n        self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072))\n        self.context_embedder = torch.nn.Linear(4096, 3072)\n        self.x_embedder = torch.nn.Linear(64, 3072)\n\n        self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(num_joint_blocks)])\n        self.single_blocks = torch.nn.ModuleList([FluxSingleTransformerBlock(3072, 24) for _ in range(num_single_blocks)])\n\n        self.controlnet_blocks = torch.nn.ModuleList([torch.nn.Linear(3072, 3072) for _ in range(num_joint_blocks)])\n        self.controlnet_single_blocks = torch.nn.ModuleList([torch.nn.Linear(3072, 3072) for _ in range(num_single_blocks)])\n\n        # Fresh dict per instance: a shared mutable default would leak mutations between models.\n        self.mode_dict = {} if mode_dict is None else mode_dict\n        self.controlnet_mode_embedder = torch.nn.Embedding(num_mode, 3072) if len(self.mode_dict) > 0 else None\n        self.controlnet_x_embedder = torch.nn.Linear(64 + additional_input_dim, 3072)\n\n\n    def prepare_image_ids(self, latents):\n        batch_size, _, height, width = latents.shape\n        latent_image_ids = torch.zeros(height // 2, width // 2, 3)\n        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]\n        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]\n\n        latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape\n\n        latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)\n        latent_image_ids = latent_image_ids.reshape(\n            batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels\n        )\n        latent_image_ids = latent_image_ids.to(device=latents.device, dtype=latents.dtype)\n\n        return latent_image_ids\n\n\n    def patchify(self, hidden_states):\n        hidden_states = rearrange(hidden_states, \"B C (H P) (W Q) -> B (H W) (C P Q)\", P=2, Q=2)\n        return hidden_states\n\n\n    def align_res_stack_to_original_blocks(self, res_stack, num_blocks, hidden_states):\n        if len(res_stack) == 0:\n            return [torch.zeros_like(hidden_states)] * num_blocks\n        interval = (num_blocks + len(res_stack) - 1) // len(res_stack)\n        aligned_res_stack = [res_stack[block_id // interval] for block_id in range(num_blocks)]\n        return aligned_res_stack\n\n\n    def forward(\n        self,\n        hidden_states,\n        controlnet_conditioning,\n        timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None,\n        processor_id=None,\n        tiled=False, tile_size=128, tile_stride=64,\n        **kwargs\n    ):\n        if image_ids is None:\n            image_ids = self.prepare_image_ids(hidden_states)\n\n        conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb)\n        if self.guidance_embedder is not None:\n            guidance = guidance * 1000\n            conditioning = conditioning + self.guidance_embedder(guidance, hidden_states.dtype)\n        prompt_emb = self.context_embedder(prompt_emb)\n        if self.controlnet_mode_embedder is not None: # Different from FluxDiT\n            processor_id = torch.tensor([self.mode_dict[processor_id]], dtype=torch.int)\n            # One mode token per batch element so the concat with prompt_emb also works for batch > 1.\n            processor_id = repeat(processor_id, \"D -> B D\", B=prompt_emb.shape[0]).to(text_ids.device)\n            prompt_emb = torch.concat([self.controlnet_mode_embedder(processor_id), prompt_emb], dim=1)\n            text_ids = torch.cat([text_ids[:, :1], text_ids], dim=1)\n        image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))\n\n        hidden_states = self.patchify(hidden_states)\n        hidden_states = self.x_embedder(hidden_states)\n        controlnet_conditioning = self.patchify(controlnet_conditioning) # Different from FluxDiT\n        hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_conditioning) # Different from FluxDiT\n\n        controlnet_res_stack = []\n        for block, controlnet_block in zip(self.blocks, self.controlnet_blocks):\n            hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)\n            controlnet_res_stack.append(controlnet_block(hidden_states))\n\n        controlnet_single_res_stack = []\n        hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)\n        for block, controlnet_block in zip(self.single_blocks, self.controlnet_single_blocks):\n            hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)\n            controlnet_single_res_stack.append(controlnet_block(hidden_states[:, prompt_emb.shape[1]:]))\n\n        controlnet_res_stack = self.align_res_stack_to_original_blocks(controlnet_res_stack, 19, hidden_states[:, prompt_emb.shape[1]:])\n        controlnet_single_res_stack = self.align_res_stack_to_original_blocks(controlnet_single_res_stack, 38, hidden_states[:, prompt_emb.shape[1]:])\n\n        return controlnet_res_stack, controlnet_single_res_stack\n\n\n    # @staticmethod\n    # def state_dict_converter():\n    #     return FluxControlNetStateDictConverter()\n\n    def quantize(self):\n        \"\"\"Recursively replace Linear, RMSNorm and Embedding submodules in place with wrappers that\n        cast their weights to the dtype and device of the input on every forward call.\n        \"\"\"\n        def cast_to(weight, dtype=None, device=None, copy=False):\n            if device is None or weight.device == device:\n                if not copy:\n                    if dtype is None or weight.dtype == dtype:\n                        return weight\n                return weight.to(dtype=dtype, copy=copy)\n\n            r = torch.empty_like(weight, dtype=dtype, device=device)\n            r.copy_(weight)\n            return r\n\n        def cast_weight(s, input=None, dtype=None, device=None):\n            if input is not None:\n                if dtype is None:\n                    dtype = input.dtype\n                if device is None:\n                    device = input.device\n            weight = cast_to(s.weight, dtype, device)\n            return weight\n\n        def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):\n            if input is not None:\n                if dtype is None:\n                    dtype = input.dtype\n                if bias_dtype is None:\n                    bias_dtype = dtype\n                if device is None:\n                    device = input.device\n            weight = cast_to(s.weight, dtype, device)\n            # Layers built with bias=False have no bias to cast (cast_to(None, ...) would raise).\n            bias = cast_to(s.bias, bias_dtype, device) if s.bias is not None else None\n            return weight, bias\n\n        class quantized_layer:\n            class QLinear(torch.nn.Linear):\n                def __init__(self, *args, **kwargs):\n                    super().__init__(*args, **kwargs)\n\n                def forward(self,input,**kwargs):\n                    weight,bias= cast_bias_weight(self,input)\n                    return torch.nn.functional.linear(input,weight,bias)\n\n            class QRMSNorm(torch.nn.Module):\n                def __init__(self, module):\n                    super().__init__()\n                    self.module = module\n\n                def forward(self,hidden_states,**kwargs):\n                    weight= cast_weight(self.module,hidden_states)\n                    input_dtype = hidden_states.dtype\n                    variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)\n                    hidden_states = hidden_states * torch.rsqrt(variance + self.module.eps)\n                    hidden_states = hidden_states.to(input_dtype) * weight\n                    return hidden_states\n\n            class QEmbedding(torch.nn.Embedding):\n                def __init__(self, *args, **kwargs):\n                    super().__init__(*args, **kwargs)\n\n                def forward(self,input,**kwargs):\n                    weight= cast_weight(self,input)\n                    return torch.nn.functional.embedding(\n                        input, weight, self.padding_idx, self.max_norm,\n                        self.norm_type, self.scale_grad_by_freq, self.sparse)\n\n        def replace_layer(model):\n            for name, module in model.named_children():\n                if isinstance(module,quantized_layer.QRMSNorm):\n                    continue\n                if isinstance(module, torch.nn.Linear):\n                    # Match the original layer's bias flag; otherwise a bias-free Linear would keep\n                    # the placeholder bias created on the meta device by init_weights_on_device.\n                    with init_weights_on_device():\n                        new_layer = quantized_layer.QLinear(module.in_features,module.out_features,bias=module.bias is not None)\n                    new_layer.weight = module.weight\n                    if module.bias is not None:\n                        new_layer.bias = module.bias\n                    setattr(model, name, new_layer)\n                elif isinstance(module, RMSNorm):\n                    if hasattr(module,\"quantized\"):\n                        continue\n                    module.quantized= True\n                    new_layer = quantized_layer.QRMSNorm(module)\n                    setattr(model, name, new_layer)\n                elif isinstance(module,torch.nn.Embedding):\n                    rows, cols = module.weight.shape\n                    new_layer = quantized_layer.QEmbedding(\n                        num_embeddings=rows,\n                        embedding_dim=cols,\n                        _weight=module.weight,\n                        # _freeze=module.freeze,\n                        padding_idx=module.padding_idx,\n                        max_norm=module.max_norm,\n                        norm_type=module.norm_type,\n                        scale_grad_by_freq=module.scale_grad_by_freq,\n                        sparse=module.sparse)\n                    setattr(model, name, new_layer)\n                else:\n                    replace_layer(module)\n\n        replace_layer(self)\n\n\n\nclass FluxControlNetStateDictConverter:\n    def __init__(self):\n        pass\n\n    def from_diffusers(self, state_dict):\n        \"\"\"Convert a diffusers-format Flux ControlNet state dict to this module's layout.\n\n        Renames keys, fuses the separate q/k/v projections (plus the single-block MLP input) into\n        the concatenated weights used here, and selects constructor kwargs by matching the hash of\n        the original key layout against known checkpoints.\n\n        Returns:\n            (converted_state_dict, extra_kwargs) where extra_kwargs is passed to FluxControlNet.\n        \"\"\"\n        hash_value = hash_state_dict_keys(state_dict)\n        global_rename_dict = {\n            \"context_embedder\": \"context_embedder\",\n            \"x_embedder\": \"x_embedder\",\n            \"time_text_embed.timestep_embedder.linear_1\": \"time_embedder.timestep_embedder.0\",\n            \"time_text_embed.timestep_embedder.linear_2\": \"time_embedder.timestep_embedder.2\",\n            \"time_text_embed.guidance_embedder.linear_1\": \"guidance_embedder.timestep_embedder.0\",\n            \"time_text_embed.guidance_embedder.linear_2\": \"guidance_embedder.timestep_embedder.2\",\n            \"time_text_embed.text_embedder.linear_1\": \"pooled_text_embedder.0\",\n            \"time_text_embed.text_embedder.linear_2\": \"pooled_text_embedder.2\",\n            \"norm_out.linear\": \"final_norm_out.linear\",\n            \"proj_out\": \"final_proj_out\",\n        }\n        rename_dict = {\n            \"proj_out\": \"proj_out\",\n            \"norm1.linear\": \"norm1_a.linear\",\n            \"norm1_context.linear\": \"norm1_b.linear\",\n            \"attn.to_q\": \"attn.a_to_q\",\n            \"attn.to_k\": \"attn.a_to_k\",\n            \"attn.to_v\": \"attn.a_to_v\",\n            \"attn.to_out.0\": \"attn.a_to_out\",\n            \"attn.add_q_proj\": \"attn.b_to_q\",\n            \"attn.add_k_proj\": \"attn.b_to_k\",\n            \"attn.add_v_proj\": \"attn.b_to_v\",\n            \"attn.to_add_out\": \"attn.b_to_out\",\n            \"ff.net.0.proj\": \"ff_a.0\",\n            \"ff.net.2\": \"ff_a.2\",\n            \"ff_context.net.0.proj\": \"ff_b.0\",\n            \"ff_context.net.2\": \"ff_b.2\",\n            \"attn.norm_q\": \"attn.norm_q_a\",\n            \"attn.norm_k\": \"attn.norm_k_a\",\n            \"attn.norm_added_q\": \"attn.norm_q_b\",\n            \"attn.norm_added_k\": \"attn.norm_k_b\",\n        }\n        rename_dict_single = {\n            \"attn.to_q\": \"a_to_q\",\n            \"attn.to_k\": \"a_to_k\",\n            \"attn.to_v\": \"a_to_v\",\n            \"attn.norm_q\": \"norm_q_a\",\n            \"attn.norm_k\": \"norm_k_a\",\n            \"norm.linear\": \"norm.linear\",\n            \"proj_mlp\": \"proj_in_besides_attn\",\n            \"proj_out\": \"proj_out\",\n        }\n        state_dict_ = {}\n        for name, param in state_dict.items():\n            if name.endswith(\".weight\") or name.endswith(\".bias\"):\n                suffix = \".weight\" if name.endswith(\".weight\") else \".bias\"\n                prefix = name[:-len(suffix)]\n                if prefix in global_rename_dict:\n                    state_dict_[global_rename_dict[prefix] + suffix] = param\n                elif prefix.startswith(\"transformer_blocks.\"):\n                    names = prefix.split(\".\")\n                    names[0] = \"blocks\"\n                    middle = \".\".join(names[2:])\n                    if middle in rename_dict:\n                        name_ = \".\".join(names[:2] + [rename_dict[middle]] + [suffix[1:]])\n                        state_dict_[name_] = param\n                elif prefix.startswith(\"single_transformer_blocks.\"):\n                    names = prefix.split(\".\")\n                    names[0] = \"single_blocks\"\n                    middle = \".\".join(names[2:])\n                    if middle in rename_dict_single:\n                        name_ = \".\".join(names[:2] + [rename_dict_single[middle]] + [suffix[1:]])\n                        state_dict_[name_] = param\n                    else:\n                        state_dict_[name] = param\n                else:\n                    state_dict_[name] = param\n        for name in list(state_dict_.keys()):\n            if \".proj_in_besides_attn.\" in name:\n                name_ = name.replace(\".proj_in_besides_attn.\", \".to_qkv_mlp.\")\n                param = torch.concat([\n                    state_dict_[name.replace(\".proj_in_besides_attn.\", f\".a_to_q.\")],\n                    state_dict_[name.replace(\".proj_in_besides_attn.\", f\".a_to_k.\")],\n                    state_dict_[name.replace(\".proj_in_besides_attn.\", f\".a_to_v.\")],\n                    state_dict_[name],\n                ], dim=0)\n                state_dict_[name_] = param\n                state_dict_.pop(name.replace(\".proj_in_besides_attn.\", f\".a_to_q.\"))\n                state_dict_.pop(name.replace(\".proj_in_besides_attn.\", f\".a_to_k.\"))\n                state_dict_.pop(name.replace(\".proj_in_besides_attn.\", f\".a_to_v.\"))\n                state_dict_.pop(name)\n        for name in list(state_dict_.keys()):\n            for component in [\"a\", \"b\"]:\n                if f\".{component}_to_q.\" in name:\n                    name_ = name.replace(f\".{component}_to_q.\", f\".{component}_to_qkv.\")\n                    param = torch.concat([\n                        state_dict_[name.replace(f\".{component}_to_q.\", f\".{component}_to_q.\")],\n                        state_dict_[name.replace(f\".{component}_to_q.\", f\".{component}_to_k.\")],\n                        state_dict_[name.replace(f\".{component}_to_q.\", f\".{component}_to_v.\")],\n                    ], dim=0)\n                    state_dict_[name_] = param\n                    state_dict_.pop(name.replace(f\".{component}_to_q.\", f\".{component}_to_q.\"))\n                    state_dict_.pop(name.replace(f\".{component}_to_q.\", f\".{component}_to_k.\"))\n                    state_dict_.pop(name.replace(f\".{component}_to_q.\", f\".{component}_to_v.\"))\n        if hash_value == \"78d18b9101345ff695f312e7e62538c0\":\n            extra_kwargs = {\"num_mode\": 10, \"mode_dict\": {\"canny\": 0, \"tile\": 1, \"depth\": 2, \"blur\": 3, \"pose\": 4, \"gray\": 5, \"lq\": 6}}\n        elif hash_value == \"b001c89139b5f053c715fe772362dd2a\":\n            extra_kwargs = {\"num_single_blocks\": 0}\n        elif hash_value == \"52357cb26250681367488a8954c271e8\":\n            extra_kwargs = {\"num_joint_blocks\": 6, \"num_single_blocks\": 0, \"additional_input_dim\": 4}\n        elif hash_value == \"0cfd1740758423a2a854d67c136d1e8c\":\n            extra_kwargs = {\"num_joint_blocks\": 4, \"num_single_blocks\": 1}\n        elif hash_value == \"7f9583eb8ba86642abb9a21a4b2c9e16\":\n            extra_kwargs = {\"num_joint_blocks\": 4, \"num_single_blocks\": 10}\n        elif hash_value == \"43ad5aaa27dd4ee01b832ed16773fa52\":\n            extra_kwargs = {\"num_joint_blocks\": 6, \"num_single_blocks\": 0}\n        else:\n            extra_kwargs = {}\n        return state_dict_, extra_kwargs\n\n\n    def from_civitai(self, state_dict):\n        return self.from_diffusers(state_dict)\n"
  },
  {
    "path": "diffsynth/models/flux_dit.py",
    "content": "import torch\nfrom .general_modules import TimestepEmbeddings, AdaLayerNorm, RMSNorm\nfrom einops import rearrange\n\n\ndef interact_with_ipadapter(hidden_states, q, ip_k, ip_v, scale=1.0):\n    batch_size, num_tokens = hidden_states.shape[0:2]\n    ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v)\n    ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, num_tokens, -1)\n    hidden_states = hidden_states + scale * ip_hidden_states\n    return hidden_states\n\n\nclass RoPEEmbedding(torch.nn.Module):\n    def __init__(self, dim, theta, axes_dim):\n        super().__init__()\n        self.dim = dim\n        self.theta = theta\n        self.axes_dim = axes_dim\n\n\n    def rope(self, pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:\n        assert dim % 2 == 0, \"The dimension must be even.\"\n\n        scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim\n        omega = 1.0 / (theta**scale)\n\n        batch_size, seq_length = pos.shape\n        out = torch.einsum(\"...n,d->...nd\", pos, omega)\n        cos_out = torch.cos(out)\n        sin_out = torch.sin(out)\n\n        stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)\n        out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)\n        return out.float()\n\n\n    def forward(self, ids):\n        n_axes = ids.shape[-1]\n        emb = torch.cat([self.rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], dim=-3)\n        return emb.unsqueeze(1)\n\n\n\nclass FluxJointAttention(torch.nn.Module):\n    def __init__(self, dim_a, dim_b, num_heads, head_dim, only_out_a=False):\n        super().__init__()\n        self.num_heads = num_heads\n        self.head_dim = head_dim\n        self.only_out_a = only_out_a\n\n        self.a_to_qkv = torch.nn.Linear(dim_a, dim_a * 3)\n        self.b_to_qkv = torch.nn.Linear(dim_b, dim_b * 3)\n\n        self.norm_q_a = RMSNorm(head_dim, 
eps=1e-6)\n        self.norm_k_a = RMSNorm(head_dim, eps=1e-6)\n        self.norm_q_b = RMSNorm(head_dim, eps=1e-6)\n        self.norm_k_b = RMSNorm(head_dim, eps=1e-6)\n\n        self.a_to_out = torch.nn.Linear(dim_a, dim_a)\n        if not only_out_a:\n            self.b_to_out = torch.nn.Linear(dim_b, dim_b)\n\n\n    def apply_rope(self, xq, xk, freqs_cis):\n        xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)\n        xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)\n        xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]\n        xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]\n        return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)\n\n    def forward(self, hidden_states_a, hidden_states_b, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):\n        batch_size = hidden_states_a.shape[0]\n\n        # Part A\n        qkv_a = self.a_to_qkv(hidden_states_a)\n        qkv_a = qkv_a.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)\n        q_a, k_a, v_a = qkv_a.chunk(3, dim=1)\n        q_a, k_a = self.norm_q_a(q_a), self.norm_k_a(k_a)\n\n        # Part B\n        qkv_b = self.b_to_qkv(hidden_states_b)\n        qkv_b = qkv_b.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)\n        q_b, k_b, v_b = qkv_b.chunk(3, dim=1)\n        q_b, k_b = self.norm_q_b(q_b), self.norm_k_b(k_b)\n\n        q = torch.concat([q_b, q_a], dim=2)\n        k = torch.concat([k_b, k_a], dim=2)\n        v = torch.concat([v_b, v_a], dim=2)\n\n        q, k = self.apply_rope(q, k, image_rotary_emb)\n\n        hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)\n        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)\n        hidden_states = hidden_states.to(q.dtype)\n        hidden_states_b, hidden_states_a = hidden_states[:, 
:hidden_states_b.shape[1]], hidden_states[:, hidden_states_b.shape[1]:]\n        if ipadapter_kwargs_list is not None:\n            hidden_states_a = interact_with_ipadapter(hidden_states_a, q_a, **ipadapter_kwargs_list)\n        hidden_states_a = self.a_to_out(hidden_states_a)\n        if self.only_out_a:\n            return hidden_states_a\n        else:\n            hidden_states_b = self.b_to_out(hidden_states_b)\n            return hidden_states_a, hidden_states_b\n\n\n\nclass FluxJointTransformerBlock(torch.nn.Module):\n    def __init__(self, dim, num_attention_heads):\n        super().__init__()\n        self.norm1_a = AdaLayerNorm(dim)\n        self.norm1_b = AdaLayerNorm(dim)\n\n        self.attn = FluxJointAttention(dim, dim, num_attention_heads, dim // num_attention_heads)\n\n        self.norm2_a = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)\n        self.ff_a = torch.nn.Sequential(\n            torch.nn.Linear(dim, dim*4),\n            torch.nn.GELU(approximate=\"tanh\"),\n            torch.nn.Linear(dim*4, dim)\n        )\n\n        self.norm2_b = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)\n        self.ff_b = torch.nn.Sequential(\n            torch.nn.Linear(dim, dim*4),\n            torch.nn.GELU(approximate=\"tanh\"),\n            torch.nn.Linear(dim*4, dim)\n        )\n\n\n    def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):\n        norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb)\n        norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb)\n\n        # Attention\n        attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b, image_rotary_emb, attn_mask, ipadapter_kwargs_list)\n\n        # Part A\n        hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a\n        
norm_hidden_states_a = self.norm2_a(hidden_states_a) * (1 + scale_mlp_a) + shift_mlp_a\n        hidden_states_a = hidden_states_a + gate_mlp_a * self.ff_a(norm_hidden_states_a)\n\n        # Part B\n        hidden_states_b = hidden_states_b + gate_msa_b * attn_output_b\n        norm_hidden_states_b = self.norm2_b(hidden_states_b) * (1 + scale_mlp_b) + shift_mlp_b\n        hidden_states_b = hidden_states_b + gate_mlp_b * self.ff_b(norm_hidden_states_b)\n\n        return hidden_states_a, hidden_states_b\n\n\n\nclass FluxSingleAttention(torch.nn.Module):\n    def __init__(self, dim_a, dim_b, num_heads, head_dim):\n        super().__init__()\n        self.num_heads = num_heads\n        self.head_dim = head_dim\n\n        self.a_to_qkv = torch.nn.Linear(dim_a, dim_a * 3)\n\n        self.norm_q_a = RMSNorm(head_dim, eps=1e-6)\n        self.norm_k_a = RMSNorm(head_dim, eps=1e-6)\n\n\n    def apply_rope(self, xq, xk, freqs_cis):\n        xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)\n        xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)\n        xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]\n        xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]\n        return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)\n\n\n    def forward(self, hidden_states, image_rotary_emb):\n        batch_size = hidden_states.shape[0]\n\n        qkv_a = self.a_to_qkv(hidden_states)\n        qkv_a = qkv_a.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)\n        q_a, k_a, v = qkv_a.chunk(3, dim=1)\n        q_a, k_a = self.norm_q_a(q_a), self.norm_k_a(k_a)\n\n        q, k = self.apply_rope(q_a, k_a, image_rotary_emb)\n\n        hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v)\n        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)\n        hidden_states = hidden_states.to(q.dtype)\n        return 
hidden_states\n\n\n\nclass AdaLayerNormSingle(torch.nn.Module):\n    def __init__(self, dim):\n        super().__init__()\n        self.silu = torch.nn.SiLU()\n        self.linear = torch.nn.Linear(dim, 3 * dim, bias=True)\n        self.norm = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)\n\n\n    def forward(self, x, emb):\n        emb = self.linear(self.silu(emb))\n        shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1)\n        x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]\n        return x, gate_msa\n\n\n\nclass FluxSingleTransformerBlock(torch.nn.Module):\n    def __init__(self, dim, num_attention_heads):\n        super().__init__()\n        self.num_heads = num_attention_heads\n        self.head_dim = dim // num_attention_heads\n        self.dim = dim\n\n        self.norm = AdaLayerNormSingle(dim)\n        self.to_qkv_mlp = torch.nn.Linear(dim, dim * (3 + 4))\n        self.norm_q_a = RMSNorm(self.head_dim, eps=1e-6)\n        self.norm_k_a = RMSNorm(self.head_dim, eps=1e-6)\n\n        self.proj_out = torch.nn.Linear(dim * 5, dim)\n\n\n    def apply_rope(self, xq, xk, freqs_cis):\n        xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)\n        xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)\n        xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]\n        xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]\n        return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)\n\n\n    def process_attention(self, hidden_states, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):\n        batch_size = hidden_states.shape[0]\n\n        qkv = hidden_states.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)\n        q, k, v = qkv.chunk(3, dim=1)\n        q, k = self.norm_q_a(q), self.norm_k_a(k)\n\n        q, k = self.apply_rope(q, k, image_rotary_emb)\n\n        hidden_states = 
torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)\n        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)\n        hidden_states = hidden_states.to(q.dtype)\n        if ipadapter_kwargs_list is not None:\n            hidden_states = interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs_list)\n        return hidden_states\n\n\n    def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):\n        residual = hidden_states_a\n        norm_hidden_states, gate = self.norm(hidden_states_a, emb=temb)\n        hidden_states_a = self.to_qkv_mlp(norm_hidden_states)\n        attn_output, mlp_hidden_states = hidden_states_a[:, :, :self.dim * 3], hidden_states_a[:, :, self.dim * 3:]\n\n        attn_output = self.process_attention(attn_output, image_rotary_emb, attn_mask, ipadapter_kwargs_list)\n        mlp_hidden_states = torch.nn.functional.gelu(mlp_hidden_states, approximate=\"tanh\")\n\n        hidden_states_a = torch.cat([attn_output, mlp_hidden_states], dim=2)\n        hidden_states_a = gate.unsqueeze(1) * self.proj_out(hidden_states_a)\n        hidden_states_a = residual + hidden_states_a\n\n        return hidden_states_a, hidden_states_b\n\n\n\nclass AdaLayerNormContinuous(torch.nn.Module):\n    def __init__(self, dim):\n        super().__init__()\n        self.silu = torch.nn.SiLU()\n        self.linear = torch.nn.Linear(dim, dim * 2, bias=True)\n        self.norm = torch.nn.LayerNorm(dim, eps=1e-6, elementwise_affine=False)\n\n    def forward(self, x, conditioning):\n        emb = self.linear(self.silu(conditioning))\n        shift, scale = torch.chunk(emb, 2, dim=1)\n        x = self.norm(x) * (1 + scale)[:, None] + shift[:, None]\n        return x\n\n\n\nclass FluxDiT(torch.nn.Module):\n    def __init__(self, disable_guidance_embedder=False, input_dim=64, num_blocks=19):\n        super().__init__()\n        
self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])\n        self.time_embedder = TimestepEmbeddings(256, 3072)\n        self.guidance_embedder = None if disable_guidance_embedder else TimestepEmbeddings(256, 3072)\n        self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072))\n        self.context_embedder = torch.nn.Linear(4096, 3072)\n        self.x_embedder = torch.nn.Linear(input_dim, 3072)\n\n        self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(num_blocks)])\n        self.single_blocks = torch.nn.ModuleList([FluxSingleTransformerBlock(3072, 24) for _ in range(38)])\n\n        self.final_norm_out = AdaLayerNormContinuous(3072)\n        self.final_proj_out = torch.nn.Linear(3072, 64)\n        \n        self.input_dim = input_dim\n\n\n    def patchify(self, hidden_states):\n        hidden_states = rearrange(hidden_states, \"B C (H P) (W Q) -> B (H W) (C P Q)\", P=2, Q=2)\n        return hidden_states\n\n\n    def unpatchify(self, hidden_states, height, width):\n        hidden_states = rearrange(hidden_states, \"B (H W) (C P Q) -> B C (H P) (W Q)\", P=2, Q=2, H=height//2, W=width//2)\n        return hidden_states\n\n\n    def prepare_image_ids(self, latents):\n        batch_size, _, height, width = latents.shape\n        latent_image_ids = torch.zeros(height // 2, width // 2, 3)\n        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]\n        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]\n\n        latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape\n\n        latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)\n        latent_image_ids = latent_image_ids.reshape(\n            batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels\n        )\n        latent_image_ids = 
latent_image_ids.to(device=latents.device, dtype=latents.dtype)\n\n        return latent_image_ids\n\n\n    def construct_mask(self, entity_masks, prompt_seq_len, image_seq_len):\n        N = len(entity_masks)\n        batch_size = entity_masks[0].shape[0]\n        total_seq_len = N * prompt_seq_len + image_seq_len\n        patched_masks = [self.patchify(entity_masks[i]) for i in range(N)]\n        attention_mask = torch.ones((batch_size, total_seq_len, total_seq_len), dtype=torch.bool).to(device=entity_masks[0].device)\n\n        image_start = N * prompt_seq_len\n        image_end = N * prompt_seq_len + image_seq_len\n        # prompt-image mask\n        for i in range(N):\n            prompt_start = i * prompt_seq_len\n            prompt_end = (i + 1) * prompt_seq_len\n            image_mask = torch.sum(patched_masks[i], dim=-1) > 0\n            image_mask = image_mask.unsqueeze(1).repeat(1, prompt_seq_len, 1)\n            # prompt update with image\n            attention_mask[:, prompt_start:prompt_end, image_start:image_end] = image_mask\n            # image update with prompt\n            attention_mask[:, image_start:image_end, prompt_start:prompt_end] = image_mask.transpose(1, 2)\n        # prompt-prompt mask\n        for i in range(N):\n            for j in range(N):\n                if i != j:\n                    prompt_start_i = i * prompt_seq_len\n                    prompt_end_i = (i + 1) * prompt_seq_len\n                    prompt_start_j = j * prompt_seq_len\n                    prompt_end_j = (j + 1) * prompt_seq_len\n                    attention_mask[:, prompt_start_i:prompt_end_i, prompt_start_j:prompt_end_j] = False\n\n        attention_mask = attention_mask.float()\n        attention_mask[attention_mask == 0] = float('-inf')\n        attention_mask[attention_mask == 1] = 0\n        return attention_mask\n\n\n    def process_entity_masks(self, hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids, repeat_dim):\n        
max_masks = 0\n        attention_mask = None\n        prompt_embs = [prompt_emb]\n        if entity_masks is not None:\n            # entity_masks\n            batch_size, max_masks = entity_masks.shape[0], entity_masks.shape[1]\n            entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1)\n            entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)]\n            # global mask\n            global_mask = torch.ones_like(entity_masks[0]).to(device=hidden_states.device, dtype=hidden_states.dtype)\n            entity_masks = entity_masks + [global_mask] # append global to last\n            # attention mask\n            attention_mask = self.construct_mask(entity_masks, prompt_emb.shape[1], hidden_states.shape[1])\n            attention_mask = attention_mask.to(device=hidden_states.device, dtype=hidden_states.dtype)\n            attention_mask = attention_mask.unsqueeze(1)\n            # embds: n_masks * b * seq * d\n            local_embs = [entity_prompt_emb[:, i, None].squeeze(1) for i in range(max_masks)]\n            prompt_embs = local_embs + prompt_embs # append global to last\n        prompt_embs = [self.context_embedder(prompt_emb) for prompt_emb in prompt_embs]\n        prompt_emb = torch.cat(prompt_embs, dim=1)\n\n        # positional embedding\n        text_ids = torch.cat([text_ids] * (max_masks + 1), dim=1)\n        image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))\n        return prompt_emb, image_rotary_emb, attention_mask\n\n\n    def forward(\n        self,\n        hidden_states,\n        timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None,\n        tiled=False, tile_size=128, tile_stride=64, entity_prompt_emb=None, entity_masks=None,\n        use_gradient_checkpointing=False,\n        **kwargs\n    ):\n        # (Deprecated) The real forward is in `pipelines.flux_image`.\n        return None\n"
  },
  {
    "path": "diffsynth/models/flux_infiniteyou.py",
    "content": "import math\nimport torch\nimport torch.nn as nn\n\n\n# FFN\ndef FeedForward(dim, mult=4):\n    inner_dim = int(dim * mult)\n    return nn.Sequential(\n        nn.LayerNorm(dim),\n        nn.Linear(dim, inner_dim, bias=False),\n        nn.GELU(),\n        nn.Linear(inner_dim, dim, bias=False),\n    )\n\n\ndef reshape_tensor(x, heads):\n    bs, length, width = x.shape\n    #(bs, length, width) --> (bs, length, n_heads, dim_per_head)\n    x = x.view(bs, length, heads, -1)\n    # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)\n    x = x.transpose(1, 2)\n    # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)\n    x = x.reshape(bs, heads, length, -1)\n    return x\n\n\nclass PerceiverAttention(nn.Module):\n\n    def __init__(self, *, dim, dim_head=64, heads=8):\n        super().__init__()\n        self.scale = dim_head**-0.5\n        self.dim_head = dim_head\n        self.heads = heads\n        inner_dim = dim_head * heads\n\n        self.norm1 = nn.LayerNorm(dim)\n        self.norm2 = nn.LayerNorm(dim)\n\n        self.to_q = nn.Linear(dim, inner_dim, bias=False)\n        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)\n        self.to_out = nn.Linear(inner_dim, dim, bias=False)\n\n    def forward(self, x, latents):\n        \"\"\"\n        Args:\n            x (torch.Tensor): image features\n                shape (b, n1, D)\n            latent (torch.Tensor): latent features\n                shape (b, n2, D)\n        \"\"\"\n        x = self.norm1(x)\n        latents = self.norm2(latents)\n\n        b, l, _ = latents.shape\n\n        q = self.to_q(latents)\n        kv_input = torch.cat((x, latents), dim=-2)\n        k, v = self.to_kv(kv_input).chunk(2, dim=-1)\n\n        q = reshape_tensor(q, self.heads)\n        k = reshape_tensor(k, self.heads)\n        v = reshape_tensor(v, self.heads)\n\n        # attention\n        scale = 1 / math.sqrt(math.sqrt(self.dim_head))\n        weight = (q * 
scale) @ (k * scale).transpose(-2, -1)  # More stable with f16 than dividing afterwards\n        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)\n        out = weight @ v\n\n        out = out.permute(0, 2, 1, 3).reshape(b, l, -1)\n\n        return self.to_out(out)\n\n\nclass InfiniteYouImageProjector(nn.Module):\n\n    def __init__(\n        self,\n        dim=1280,\n        depth=4,\n        dim_head=64,\n        heads=20,\n        num_queries=8,\n        embedding_dim=512,\n        output_dim=4096,\n        ff_mult=4,\n    ):\n        super().__init__()\n        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)\n        self.proj_in = nn.Linear(embedding_dim, dim)\n\n        self.proj_out = nn.Linear(dim, output_dim)\n        self.norm_out = nn.LayerNorm(output_dim)\n\n        self.layers = nn.ModuleList([])\n        for _ in range(depth):\n            self.layers.append(\n                nn.ModuleList([\n                    PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),\n                    FeedForward(dim=dim, mult=ff_mult),\n                ]))\n\n    def forward(self, x):\n\n        latents = self.latents.repeat(x.size(0), 1, 1)\n        latents = latents.to(dtype=x.dtype, device=x.device)\n\n        x = self.proj_in(x)\n\n        for attn, ff in self.layers:\n            latents = attn(x, latents) + latents\n            latents = ff(latents) + latents\n\n        latents = self.proj_out(latents)\n        return self.norm_out(latents)\n\n    @staticmethod\n    def state_dict_converter():\n        return FluxInfiniteYouImageProjectorStateDictConverter()\n\n\nclass FluxInfiniteYouImageProjectorStateDictConverter:\n\n    def __init__(self):\n        pass\n\n    def from_diffusers(self, state_dict):\n        return state_dict['image_proj']\n"
  },
  {
    "path": "diffsynth/models/flux_ipadapter.py",
    "content": "from .general_modules import RMSNorm\nfrom transformers import SiglipVisionModel, SiglipVisionConfig\nimport torch\n\n\nclass SiglipVisionModelSO400M(SiglipVisionModel):\n    def __init__(self):\n        config = SiglipVisionConfig(\n            hidden_size=1152,\n            image_size=384,\n            intermediate_size=4304,\n            model_type=\"siglip_vision_model\",\n            num_attention_heads=16,\n            num_hidden_layers=27,\n            patch_size=14,\n            architectures=[\"SiglipModel\"],\n            initializer_factor=1.0,\n            torch_dtype=\"float32\",\n            transformers_version=\"4.37.0.dev0\"\n        )\n        super().__init__(config)\n\nclass MLPProjModel(torch.nn.Module):\n    def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4):\n        super().__init__()\n        \n        self.cross_attention_dim = cross_attention_dim\n        self.num_tokens = num_tokens\n        \n        self.proj = torch.nn.Sequential(\n            torch.nn.Linear(id_embeddings_dim, id_embeddings_dim*2),\n            torch.nn.GELU(),\n            torch.nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens),\n        )\n        self.norm = torch.nn.LayerNorm(cross_attention_dim)\n        \n    def forward(self, id_embeds):\n        x = self.proj(id_embeds)\n        x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)\n        x = self.norm(x)\n        return x\n\nclass IpAdapterModule(torch.nn.Module):\n    def __init__(self, num_attention_heads, attention_head_dim, input_dim):\n        super().__init__()\n        self.num_heads = num_attention_heads\n        self.head_dim = attention_head_dim\n        output_dim = num_attention_heads * attention_head_dim\n        self.to_k_ip = torch.nn.Linear(input_dim, output_dim, bias=False)\n        self.to_v_ip = torch.nn.Linear(input_dim, output_dim, bias=False)\n        self.norm_added_k = RMSNorm(attention_head_dim, eps=1e-5, 
elementwise_affine=False)\n        \n\n    def forward(self, hidden_states):\n        batch_size = hidden_states.shape[0]\n        # ip_k\n        ip_k = self.to_k_ip(hidden_states)\n        ip_k = ip_k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)\n        ip_k = self.norm_added_k(ip_k)\n        # ip_v\n        ip_v = self.to_v_ip(hidden_states)\n        ip_v = ip_v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)\n        return ip_k, ip_v\n\n\nclass FluxIpAdapter(torch.nn.Module):\n    def __init__(self, num_attention_heads=24, attention_head_dim=128, cross_attention_dim=4096, num_tokens=128, num_blocks=57):\n        super().__init__()\n        self.ipadapter_modules = torch.nn.ModuleList([IpAdapterModule(num_attention_heads, attention_head_dim, cross_attention_dim) for _ in range(num_blocks)])\n        self.image_proj = MLPProjModel(cross_attention_dim=cross_attention_dim, id_embeddings_dim=1152, num_tokens=num_tokens)\n        self.set_adapter()\n\n    def set_adapter(self):\n        self.call_block_id = {i:i for i in range(len(self.ipadapter_modules))}\n\n    def forward(self, hidden_states, scale=1.0):\n        hidden_states = self.image_proj(hidden_states)\n        hidden_states = hidden_states.view(1, -1, hidden_states.shape[-1])\n        ip_kv_dict = {}\n        for block_id in self.call_block_id:\n            ipadapter_id = self.call_block_id[block_id]\n            ip_k, ip_v = self.ipadapter_modules[ipadapter_id](hidden_states)\n            ip_kv_dict[block_id] = {\n                \"ip_k\": ip_k,\n                \"ip_v\": ip_v,\n                \"scale\": scale\n            }\n        return ip_kv_dict\n\n    @staticmethod\n    def state_dict_converter():\n        return FluxIpAdapterStateDictConverter()\n\n\nclass FluxIpAdapterStateDictConverter:\n    def __init__(self):\n        pass\n\n    def from_diffusers(self, state_dict):\n        state_dict_ = {}\n        for name in state_dict[\"ip_adapter\"]:\n       
     name_ = 'ipadapter_modules.' + name\n            state_dict_[name_] = state_dict[\"ip_adapter\"][name]\n        for name in state_dict[\"image_proj\"]:\n            name_ = \"image_proj.\" + name\n            state_dict_[name_] = state_dict[\"image_proj\"][name]\n        return state_dict_\n    \n    def from_civitai(self, state_dict):\n        return self.from_diffusers(state_dict)\n"
  },
  {
    "path": "diffsynth/models/flux_lora_encoder.py",
    "content": "import torch\nfrom einops import rearrange\n\n\ndef low_version_attention(query, key, value, attn_bias=None):\n    \"\"\"Plain softmax attention; `attn_bias` (if given) is added to the scaled logits.\"\"\"\n    scale = 1 / query.shape[-1] ** 0.5\n    query = query * scale\n    attn = torch.matmul(query, key.transpose(-2, -1))\n    if attn_bias is not None:\n        attn = attn + attn_bias\n    attn = attn.softmax(-1)\n    return attn @ value\n\n\nclass Attention(torch.nn.Module):\n    \"\"\"Multi-head attention with separate query/key/value/output projections.\"\"\"\n\n    def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):\n        super().__init__()\n        dim_inner = head_dim * num_heads\n        kv_dim = kv_dim if kv_dim is not None else q_dim\n        self.num_heads = num_heads\n        self.head_dim = head_dim\n\n        self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q)\n        self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)\n        self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)\n        self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out)\n\n    def interact_with_ipadapter(self, hidden_states, q, ip_k, ip_v, scale=1.0):\n        \"\"\"Add `scale` times the attention of `q` over the IP-Adapter keys/values.\"\"\"\n        batch_size = q.shape[0]\n        ip_k = ip_k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)\n        ip_v = ip_v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)\n        ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v)\n        hidden_states = hidden_states + scale * ip_hidden_states\n        return hidden_states\n\n    def torch_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None):\n        \"\"\"Attention via `scaled_dot_product_attention`; self-attention when `encoder_hidden_states` is None.\"\"\"\n        if encoder_hidden_states is None:\n            encoder_hidden_states = hidden_states\n\n        batch_size = encoder_hidden_states.shape[0]\n\n        q = self.to_q(hidden_states)\n        k = self.to_k(encoder_hidden_states)\n        v = self.to_v(encoder_hidden_states)\n\n        # Split the projected features into heads: (batch, heads, tokens, head_dim).\n        q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)\n        k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)\n        v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)\n\n        if qkv_preprocessor is not None:\n            q, k, v = qkv_preprocessor(q, k, v)\n\n        hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)\n        if ipadapter_kwargs is not None:\n            hidden_states = self.interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs)\n        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)\n        hidden_states = hidden_states.to(q.dtype)\n\n        hidden_states = self.to_out(hidden_states)\n\n        return hidden_states\n\n    def xformers_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):\n        \"\"\"Attention via xformers; uses `low_version_attention` instead when a mask is given.\"\"\"\n        if encoder_hidden_states is None:\n            encoder_hidden_states = hidden_states\n\n        q = self.to_q(hidden_states)\n        k = self.to_k(encoder_hidden_states)\n        v = self.to_v(encoder_hidden_states)\n\n        q = rearrange(q, \"b f (n d) -> (b n) f d\", n=self.num_heads)\n        k = rearrange(k, \"b f (n d) -> (b n) f d\", n=self.num_heads)\n        v = rearrange(v, \"b f (n d) -> (b n) f d\", n=self.num_heads)\n\n        if attn_mask is not None:\n            hidden_states = low_version_attention(q, k, v, attn_bias=attn_mask)\n        else:\n            import xformers.ops as xops\n            hidden_states = xops.memory_efficient_attention(q, k, v)\n        hidden_states = rearrange(hidden_states, \"(b n) f d -> b f (n d)\", n=self.num_heads)\n\n        hidden_states = hidden_states.to(q.dtype)\n        hidden_states = self.to_out(hidden_states)\n\n        return hidden_states\n\n    def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None):\n        return self.torch_forward(hidden_states, encoder_hidden_states=encoder_hidden_states, attn_mask=attn_mask, ipadapter_kwargs=ipadapter_kwargs, qkv_preprocessor=qkv_preprocessor)\n\n\nclass CLIPEncoderLayer(torch.nn.Module):\n    \"\"\"Pre-LayerNorm transformer block: self-attention then a two-layer MLP, each with a residual.\"\"\"\n\n    def __init__(self, embed_dim, intermediate_size, num_heads=12, head_dim=64, use_quick_gelu=True):\n        super().__init__()\n        self.attn = Attention(q_dim=embed_dim, num_heads=num_heads, head_dim=head_dim, bias_q=True, bias_kv=True, bias_out=True)\n        self.layer_norm1 = torch.nn.LayerNorm(embed_dim)\n        self.layer_norm2 = torch.nn.LayerNorm(embed_dim)\n        self.fc1 = torch.nn.Linear(embed_dim, intermediate_size)\n        self.fc2 = torch.nn.Linear(intermediate_size, embed_dim)\n\n        self.use_quick_gelu = use_quick_gelu\n\n    def quickGELU(self, x):\n        return x * torch.sigmoid(1.702 * x)\n\n    def forward(self, hidden_states, attn_mask=None):\n        residual = hidden_states\n\n        hidden_states = self.layer_norm1(hidden_states)\n        hidden_states = self.attn(hidden_states, attn_mask=attn_mask)\n        hidden_states = residual + hidden_states\n\n        residual = hidden_states\n        hidden_states = self.layer_norm2(hidden_states)\n        hidden_states = self.fc1(hidden_states)\n        if self.use_quick_gelu:\n            hidden_states = self.quickGELU(hidden_states)\n        else:\n            hidden_states = torch.nn.functional.gelu(hidden_states)\n        hidden_states = self.fc2(hidden_states)\n        hidden_states = residual + hidden_states\n\n        return hidden_states\n\n\nclass SDTextEncoder(torch.nn.Module):\n    \"\"\"CLIP-style text encoder with a causal attention mask over `max_position_embeddings` tokens.\"\"\"\n\n    def __init__(self, embed_dim=768, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=12, encoder_intermediate_size=3072):\n        super().__init__()\n\n        # token_embedding\n        self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim)\n\n        # position_embeds (This is a fixed tensor)\n        self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim))\n\n        # encoders\n        self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size) for _ in range(num_encoder_layers)])\n\n        # attn_mask\n        self.attn_mask = self.attention_mask(max_position_embeddings)\n\n        # final_layer_norm\n        self.final_layer_norm = torch.nn.LayerNorm(embed_dim)\n\n    def attention_mask(self, length):\n        \"\"\"Causal mask: -inf strictly above the diagonal, 0 elsewhere.\"\"\"\n        mask = torch.empty(length, length)\n        mask.fill_(float(\"-inf\"))\n        mask.triu_(1)\n        return mask\n\n    def forward(self, input_ids, clip_skip=1):\n        \"\"\"Encode `input_ids`; `clip_skip=k` runs only the first len(encoders) - k + 1 layers.\"\"\"\n        embeds = self.token_embedding(input_ids) + self.position_embeds\n        attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)\n        for encoder_id, encoder in enumerate(self.encoders):\n            embeds = encoder(embeds, attn_mask=attn_mask)\n            if encoder_id + clip_skip == len(self.encoders):\n                break\n        embeds = self.final_layer_norm(embeds)\n        return embeds\n\n    @staticmethod\n    def state_dict_converter():\n        return SDTextEncoderStateDictConverter()\n\n\nclass SDTextEncoderStateDictConverter:\n    \"\"\"Maps diffusers / single-file (civitai) checkpoint keys onto `SDTextEncoder` parameter names.\"\"\"\n\n    def __init__(self):\n        pass\n\n    def from_diffusers(self, state_dict):\n        rename_dict = {\n            \"text_model.embeddings.token_embedding.weight\": \"token_embedding.weight\",\n            \"text_model.embeddings.position_embedding.weight\": \"position_embeds\",\n            \"text_model.final_layer_norm.weight\": \"final_layer_norm.weight\",\n            \"text_model.final_layer_norm.bias\": \"final_layer_norm.bias\"\n        }\n        attn_rename_dict = {\n            \"self_attn.q_proj\": \"attn.to_q\",\n            \"self_attn.k_proj\": \"attn.to_k\",\n            \"self_attn.v_proj\": \"attn.to_v\",\n            \"self_attn.out_proj\": \"attn.to_out\",\n            \"layer_norm1\": \"layer_norm1\",\n            \"layer_norm2\": \"layer_norm2\",\n            \"mlp.fc1\": \"fc1\",\n            \"mlp.fc2\": \"fc2\",\n        }\n        state_dict_ = {}\n        for name in state_dict:\n            if name in rename_dict:\n                param = state_dict[name]\n                if name == \"text_model.embeddings.position_embedding.weight\":\n                    param = param.reshape((1, param.shape[0], param.shape[1]))\n                state_dict_[rename_dict[name]] = param\n            elif name.startswith(\"text_model.encoder.layers.\"):\n                param = state_dict[name]\n                names = name.split(\".\")\n                layer_id, layer_type, tail = names[3], \".\".join(names[4:-1]), names[-1]\n                name_ = \".\".join([\"encoders\", layer_id, attn_rename_dict[layer_type], tail])\n                state_dict_[name_] = param\n        return state_dict_\n\n    def from_civitai(self, state_dict):\n        \"\"\"Rename the text-encoder weights of a single-file checkpoint; all other keys are dropped.\"\"\"\n        prefix = \"cond_stage_model.transformer.text_model.\"\n        position_key = prefix + \"embeddings.position_embedding.weight\"\n        rename_dict = {\n            prefix + \"embeddings.token_embedding.weight\": \"token_embedding.weight\",\n            position_key: \"position_embeds\",\n            prefix + \"final_layer_norm.bias\": \"final_layer_norm.bias\",\n            prefix + \"final_layer_norm.weight\": \"final_layer_norm.weight\",\n        }\n        # Per-layer renames, generated for encoder layers 0-11 (weight and bias of each).\n        layer_rename_dict = {\n            \"layer_norm1\": \"layer_norm1\",\n            \"layer_norm2\": \"layer_norm2\",\n            \"mlp.fc1\": \"fc1\",\n            \"mlp.fc2\": \"fc2\",\n            \"self_attn.q_proj\": \"attn.to_q\",\n            \"self_attn.k_proj\": \"attn.to_k\",\n            \"self_attn.v_proj\": \"attn.to_v\",\n            \"self_attn.out_proj\": \"attn.to_out\",\n        }\n        for layer_id in range(12):\n            for source, target in layer_rename_dict.items():\n                for tail in (\"weight\", \"bias\"):\n                    rename_dict[f\"{prefix}encoder.layers.{layer_id}.{source}.{tail}\"] = f\"encoders.{layer_id}.{target}.{tail}\"\n        state_dict_ = {}\n        for name in state_dict:\n            if name in rename_dict:\n                param = state_dict[name]\n                if name == position_key:\n                    param = param.reshape((1, param.shape[0], param.shape[1]))\n                state_dict_[rename_dict[name]] = param\n        return state_dict_\n\n\nclass LoRALayerBlock(torch.nn.Module):\n    \"\"\"Embeds one LoRA pair by pushing a learned (1, L, dim_in) probe through A then B.\"\"\"\n\n    def __init__(self, L, dim_in, dim_out):\n        super().__init__()\n        self.x = torch.nn.Parameter(torch.randn(1, L, dim_in))\n        self.layer_norm = torch.nn.LayerNorm(dim_out)\n\n    def forward(self, lora_A, lora_B):\n        x = self.x @ lora_A.T @ lora_B.T\n        x = self.layer_norm(x)\n        return x\n\n\nclass LoRAEmbedder(torch.nn.Module):\n    \"\"\"Turns a LoRA state dict into a token sequence: one block per target layer, one projection per layer type.\"\"\"\n\n    def __init__(self, lora_patterns=None, L=1, out_dim=2048):\n        super().__init__()\n        if lora_patterns is None:\n            lora_patterns = self.default_lora_patterns()\n\n        # ModuleDict keys cannot contain dots, so they are replaced with ___ in every key.\n        model_dict = {}\n        for lora_pattern in lora_patterns:\n            name, dim = lora_pattern[\"name\"], lora_pattern[\"dim\"]\n            model_dict[name.replace(\".\", \"___\")] = LoRALayerBlock(L, dim[0], dim[1])\n        self.model_dict = torch.nn.ModuleDict(model_dict)\n\n        proj_dict = {}\n        for lora_pattern in lora_patterns:\n            layer_type, dim = lora_pattern[\"type\"], lora_pattern[\"dim\"]\n            proj_name = layer_type.replace(\".\", \"___\")\n            # Compare against the sanitized key actually stored; comparing the raw dotted\n            # type never matches and would rebuild this Linear once for every block.\n            if proj_name not in proj_dict:\n                proj_dict[proj_name] = torch.nn.Linear(dim[1], out_dim)\n        self.proj_dict = torch.nn.ModuleDict(proj_dict)\n\n        self.lora_patterns = lora_patterns\n\n    def default_lora_patterns(self):\n        \"\"\"Default targets: 19 `blocks` and 38 `single_blocks`, with (in_features, out_features) per layer.\"\"\"\n        lora_patterns = []\n        lora_dict = {\n            \"attn.a_to_qkv\": (3072, 9216), \"attn.a_to_out\": (3072, 3072), \"ff_a.0\": (3072, 12288), \"ff_a.2\": (12288, 3072), \"norm1_a.linear\": (3072, 18432),\n            \"attn.b_to_qkv\": (3072, 9216), \"attn.b_to_out\": (3072, 3072), \"ff_b.0\": (3072, 12288), \"ff_b.2\": (12288, 3072), \"norm1_b.linear\": (3072, 18432),\n        }\n        for i in range(19):\n            for suffix in lora_dict:\n                lora_patterns.append({\n                    \"name\": f\"blocks.{i}.{suffix}\",\n                    \"dim\": lora_dict[suffix],\n                    \"type\": suffix,\n                })\n        lora_dict = {\"to_qkv_mlp\": (3072, 21504), \"proj_out\": (15360, 3072), \"norm.linear\": (3072, 9216)}\n        for i in range(38):\n            for suffix in lora_dict:\n                lora_patterns.append({\n                    \"name\": f\"single_blocks.{i}.{suffix}\",\n                    \"dim\": lora_dict[suffix],\n                    \"type\": suffix,\n                })\n        return lora_patterns\n\n    def forward(self, lora):\n        \"\"\"Embed every pattern of `lora` and concatenate the results along dim 1.\"\"\"\n        lora_emb = []\n        for lora_pattern in self.lora_patterns:\n            name, layer_type = lora_pattern[\"name\"], lora_pattern[\"type\"]\n            lora_A = lora[name + \".lora_A.weight\"]\n            lora_B = lora[name + \".lora_B.weight\"]\n            lora_out = self.model_dict[name.replace(\".\", \"___\")](lora_A, lora_B)\n            lora_out = self.proj_dict[layer_type.replace(\".\", \"___\")](lora_out)\n            lora_emb.append(lora_out)\n        lora_emb = torch.concat(lora_emb, dim=1)\n        return lora_emb\n\n\nclass FluxLoRAEncoder(torch.nn.Module):\n    \"\"\"Encodes a LoRA into `num_special_embeds` vectors using learned special tokens and CLIP-style layers.\"\"\"\n\n    def __init__(self, embed_dim=4096, encoder_intermediate_size=8192, num_encoder_layers=1, num_embeds_per_lora=16, num_special_embeds=1):\n        super().__init__()\n        self.num_embeds_per_lora = num_embeds_per_lora\n        # embedder\n        self.embedder = LoRAEmbedder(L=num_embeds_per_lora, out_dim=embed_dim)\n\n        # encoders\n        self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size, num_heads=32, head_dim=128) for _ in range(num_encoder_layers)])\n\n        # special embedding\n        self.special_embeds = torch.nn.Parameter(torch.randn(1, num_special_embeds, embed_dim))\n        self.num_special_embeds = num_special_embeds\n\n        # final layer\n        self.final_layer_norm = torch.nn.LayerNorm(embed_dim)\n        self.final_linear = torch.nn.Linear(embed_dim, embed_dim)\n\n    def forward(self, lora):\n        lora_embeds = self.embedder(lora)\n        special_embeds = self.special_embeds.to(dtype=lora_embeds.dtype, device=lora_embeds.device)\n        embeds = torch.concat([special_embeds, lora_embeds], dim=1)\n        for encoder in self.encoders:\n            embeds = encoder(embeds)\n        # Only the special-token positions are kept as the output.\n        embeds = embeds[:, :self.num_special_embeds]\n        embeds = self.final_layer_norm(embeds)\n        embeds = self.final_linear(embeds)\n        return embeds\n\n    @staticmethod\n    def state_dict_converter():\n        return FluxLoRAEncoderStateDictConverter()\n\n\nclass FluxLoRAEncoderStateDictConverter:\n    def from_civitai(self, state_dict):\n        return state_dict\n"
  },
  {
    "path": "diffsynth/models/flux_lora_patcher.py",
    "content": "import torch, math\nfrom ..core.loader import load_state_dict\nfrom typing import Union\n\nclass GeneralLoRALoader:\n    def __init__(self, device=\"cpu\", torch_dtype=torch.float32):\n        self.device = device\n        self.torch_dtype = torch_dtype\n    \n    \n    def get_name_dict(self, lora_state_dict):\n        lora_name_dict = {}\n        for key in lora_state_dict:\n            if \".lora_B.\" not in key:\n                continue\n            keys = key.split(\".\")\n            if len(keys) > keys.index(\"lora_B\") + 2:\n                keys.pop(keys.index(\"lora_B\") + 1)\n            keys.pop(keys.index(\"lora_B\"))\n            if keys[0] == \"diffusion_model\":\n                keys.pop(0)\n            keys.pop(-1)\n            target_name = \".\".join(keys)\n            lora_name_dict[target_name] = (key, key.replace(\".lora_B.\", \".lora_A.\"))\n        return lora_name_dict\n\n\n    def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0):\n        updated_num = 0\n        lora_name_dict = self.get_name_dict(state_dict_lora)\n        for name, module in model.named_modules():\n            if name in lora_name_dict:\n                weight_up = state_dict_lora[lora_name_dict[name][0]].to(device=self.device, dtype=self.torch_dtype)\n                weight_down = state_dict_lora[lora_name_dict[name][1]].to(device=self.device, dtype=self.torch_dtype)\n                if len(weight_up.shape) == 4:\n                    weight_up = weight_up.squeeze(3).squeeze(2)\n                    weight_down = weight_down.squeeze(3).squeeze(2)\n                    weight_lora = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)\n                else:\n                    weight_lora = alpha * torch.mm(weight_up, weight_down)\n                state_dict = module.state_dict()\n                state_dict[\"weight\"] = state_dict[\"weight\"].to(device=self.device, dtype=self.torch_dtype) + weight_lora\n                
module.load_state_dict(state_dict)\n                updated_num += 1\n        print(f\"{updated_num} tensors are updated by LoRA.\")\n\nclass FluxLoRALoader(GeneralLoRALoader):\n    def __init__(self, device=\"cpu\", torch_dtype=torch.float32):\n        super().__init__(device=device, torch_dtype=torch_dtype)\n    \n        self.diffusers_rename_dict = {\n            \"transformer.single_transformer_blocks.blockid.attn.to_k.lora_A.weight\":\"single_blocks.blockid.a_to_k.lora_A.default.weight\",\n            \"transformer.single_transformer_blocks.blockid.attn.to_k.lora_B.weight\":\"single_blocks.blockid.a_to_k.lora_B.default.weight\",\n            \"transformer.single_transformer_blocks.blockid.attn.to_q.lora_A.weight\":\"single_blocks.blockid.a_to_q.lora_A.default.weight\",\n            \"transformer.single_transformer_blocks.blockid.attn.to_q.lora_B.weight\":\"single_blocks.blockid.a_to_q.lora_B.default.weight\",\n            \"transformer.single_transformer_blocks.blockid.attn.to_v.lora_A.weight\":\"single_blocks.blockid.a_to_v.lora_A.default.weight\",\n            \"transformer.single_transformer_blocks.blockid.attn.to_v.lora_B.weight\":\"single_blocks.blockid.a_to_v.lora_B.default.weight\",\n            \"transformer.single_transformer_blocks.blockid.norm.linear.lora_A.weight\":\"single_blocks.blockid.norm.linear.lora_A.default.weight\",\n            \"transformer.single_transformer_blocks.blockid.norm.linear.lora_B.weight\":\"single_blocks.blockid.norm.linear.lora_B.default.weight\",\n            \"transformer.single_transformer_blocks.blockid.proj_mlp.lora_A.weight\":\"single_blocks.blockid.proj_in_besides_attn.lora_A.default.weight\",\n            \"transformer.single_transformer_blocks.blockid.proj_mlp.lora_B.weight\":\"single_blocks.blockid.proj_in_besides_attn.lora_B.default.weight\",\n            \"transformer.single_transformer_blocks.blockid.proj_out.lora_A.weight\":\"single_blocks.blockid.proj_out.lora_A.default.weight\",\n            
\"transformer.single_transformer_blocks.blockid.proj_out.lora_B.weight\":\"single_blocks.blockid.proj_out.lora_B.default.weight\",\n            \"transformer.transformer_blocks.blockid.attn.add_k_proj.lora_A.weight\":\"blocks.blockid.attn.b_to_k.lora_A.default.weight\",\n            \"transformer.transformer_blocks.blockid.attn.add_k_proj.lora_B.weight\":\"blocks.blockid.attn.b_to_k.lora_B.default.weight\",\n            \"transformer.transformer_blocks.blockid.attn.add_q_proj.lora_A.weight\":\"blocks.blockid.attn.b_to_q.lora_A.default.weight\",\n            \"transformer.transformer_blocks.blockid.attn.add_q_proj.lora_B.weight\":\"blocks.blockid.attn.b_to_q.lora_B.default.weight\",\n            \"transformer.transformer_blocks.blockid.attn.add_v_proj.lora_A.weight\":\"blocks.blockid.attn.b_to_v.lora_A.default.weight\",\n            \"transformer.transformer_blocks.blockid.attn.add_v_proj.lora_B.weight\":\"blocks.blockid.attn.b_to_v.lora_B.default.weight\",\n            \"transformer.transformer_blocks.blockid.attn.to_add_out.lora_A.weight\":\"blocks.blockid.attn.b_to_out.lora_A.default.weight\",\n            \"transformer.transformer_blocks.blockid.attn.to_add_out.lora_B.weight\":\"blocks.blockid.attn.b_to_out.lora_B.default.weight\",\n            \"transformer.transformer_blocks.blockid.attn.to_k.lora_A.weight\":\"blocks.blockid.attn.a_to_k.lora_A.default.weight\",\n            \"transformer.transformer_blocks.blockid.attn.to_k.lora_B.weight\":\"blocks.blockid.attn.a_to_k.lora_B.default.weight\",\n            \"transformer.transformer_blocks.blockid.attn.to_out.0.lora_A.weight\":\"blocks.blockid.attn.a_to_out.lora_A.default.weight\",\n            \"transformer.transformer_blocks.blockid.attn.to_out.0.lora_B.weight\":\"blocks.blockid.attn.a_to_out.lora_B.default.weight\",\n            \"transformer.transformer_blocks.blockid.attn.to_q.lora_A.weight\":\"blocks.blockid.attn.a_to_q.lora_A.default.weight\",\n            
\"transformer.transformer_blocks.blockid.attn.to_q.lora_B.weight\":\"blocks.blockid.attn.a_to_q.lora_B.default.weight\",\n            \"transformer.transformer_blocks.blockid.attn.to_v.lora_A.weight\":\"blocks.blockid.attn.a_to_v.lora_A.default.weight\",\n            \"transformer.transformer_blocks.blockid.attn.to_v.lora_B.weight\":\"blocks.blockid.attn.a_to_v.lora_B.default.weight\",\n            \"transformer.transformer_blocks.blockid.ff.net.0.proj.lora_A.weight\":\"blocks.blockid.ff_a.0.lora_A.default.weight\",\n            \"transformer.transformer_blocks.blockid.ff.net.0.proj.lora_B.weight\":\"blocks.blockid.ff_a.0.lora_B.default.weight\",\n            \"transformer.transformer_blocks.blockid.ff.net.2.lora_A.weight\":\"blocks.blockid.ff_a.2.lora_A.default.weight\",\n            \"transformer.transformer_blocks.blockid.ff.net.2.lora_B.weight\":\"blocks.blockid.ff_a.2.lora_B.default.weight\",\n            \"transformer.transformer_blocks.blockid.ff_context.net.0.proj.lora_A.weight\":\"blocks.blockid.ff_b.0.lora_A.default.weight\",\n            \"transformer.transformer_blocks.blockid.ff_context.net.0.proj.lora_B.weight\":\"blocks.blockid.ff_b.0.lora_B.default.weight\",\n            \"transformer.transformer_blocks.blockid.ff_context.net.2.lora_A.weight\":\"blocks.blockid.ff_b.2.lora_A.default.weight\",\n            \"transformer.transformer_blocks.blockid.ff_context.net.2.lora_B.weight\":\"blocks.blockid.ff_b.2.lora_B.default.weight\",\n            \"transformer.transformer_blocks.blockid.norm1.linear.lora_A.weight\":\"blocks.blockid.norm1_a.linear.lora_A.default.weight\",\n            \"transformer.transformer_blocks.blockid.norm1.linear.lora_B.weight\":\"blocks.blockid.norm1_a.linear.lora_B.default.weight\",\n            \"transformer.transformer_blocks.blockid.norm1_context.linear.lora_A.weight\":\"blocks.blockid.norm1_b.linear.lora_A.default.weight\",\n            
\"transformer.transformer_blocks.blockid.norm1_context.linear.lora_B.weight\":\"blocks.blockid.norm1_b.linear.lora_B.default.weight\",\n        }\n\n        self.civitai_rename_dict = {\n            \"lora_unet_double_blocks_blockid_img_mod_lin.lora_down.weight\": \"blocks.blockid.norm1_a.linear.lora_A.default.weight\",\n            \"lora_unet_double_blocks_blockid_img_mod_lin.lora_up.weight\": \"blocks.blockid.norm1_a.linear.lora_B.default.weight\",\n            \"lora_unet_double_blocks_blockid_txt_mod_lin.lora_down.weight\": \"blocks.blockid.norm1_b.linear.lora_A.default.weight\",\n            \"lora_unet_double_blocks_blockid_txt_mod_lin.lora_up.weight\": \"blocks.blockid.norm1_b.linear.lora_B.default.weight\",\n            \"lora_unet_double_blocks_blockid_img_attn_qkv.lora_down.weight\": \"blocks.blockid.attn.a_to_qkv.lora_A.default.weight\",\n            \"lora_unet_double_blocks_blockid_img_attn_qkv.lora_up.weight\": \"blocks.blockid.attn.a_to_qkv.lora_B.default.weight\",\n            \"lora_unet_double_blocks_blockid_txt_attn_qkv.lora_down.weight\": \"blocks.blockid.attn.b_to_qkv.lora_A.default.weight\",\n            \"lora_unet_double_blocks_blockid_txt_attn_qkv.lora_up.weight\": \"blocks.blockid.attn.b_to_qkv.lora_B.default.weight\",\n            \"lora_unet_double_blocks_blockid_img_attn_proj.lora_down.weight\": \"blocks.blockid.attn.a_to_out.lora_A.default.weight\",\n            \"lora_unet_double_blocks_blockid_img_attn_proj.lora_up.weight\": \"blocks.blockid.attn.a_to_out.lora_B.default.weight\",\n            \"lora_unet_double_blocks_blockid_txt_attn_proj.lora_down.weight\": \"blocks.blockid.attn.b_to_out.lora_A.default.weight\",\n            \"lora_unet_double_blocks_blockid_txt_attn_proj.lora_up.weight\": \"blocks.blockid.attn.b_to_out.lora_B.default.weight\",\n            \"lora_unet_double_blocks_blockid_img_mlp_0.lora_down.weight\": \"blocks.blockid.ff_a.0.lora_A.default.weight\",\n            
\"lora_unet_double_blocks_blockid_img_mlp_0.lora_up.weight\": \"blocks.blockid.ff_a.0.lora_B.default.weight\",\n            \"lora_unet_double_blocks_blockid_img_mlp_2.lora_down.weight\": \"blocks.blockid.ff_a.2.lora_A.default.weight\",\n            \"lora_unet_double_blocks_blockid_img_mlp_2.lora_up.weight\": \"blocks.blockid.ff_a.2.lora_B.default.weight\",\n            \"lora_unet_double_blocks_blockid_txt_mlp_0.lora_down.weight\": \"blocks.blockid.ff_b.0.lora_A.default.weight\",\n            \"lora_unet_double_blocks_blockid_txt_mlp_0.lora_up.weight\": \"blocks.blockid.ff_b.0.lora_B.default.weight\",\n            \"lora_unet_double_blocks_blockid_txt_mlp_2.lora_down.weight\": \"blocks.blockid.ff_b.2.lora_A.default.weight\",\n            \"lora_unet_double_blocks_blockid_txt_mlp_2.lora_up.weight\": \"blocks.blockid.ff_b.2.lora_B.default.weight\",\n            \"lora_unet_single_blocks_blockid_modulation_lin.lora_down.weight\": \"single_blocks.blockid.norm.linear.lora_A.default.weight\",\n            \"lora_unet_single_blocks_blockid_modulation_lin.lora_up.weight\": \"single_blocks.blockid.norm.linear.lora_B.default.weight\",\n            \"lora_unet_single_blocks_blockid_linear1.lora_down.weight\": \"single_blocks.blockid.to_qkv_mlp.lora_A.default.weight\",\n            \"lora_unet_single_blocks_blockid_linear1.lora_up.weight\": \"single_blocks.blockid.to_qkv_mlp.lora_B.default.weight\",\n            \"lora_unet_single_blocks_blockid_linear2.lora_down.weight\": \"single_blocks.blockid.proj_out.lora_A.default.weight\",\n            \"lora_unet_single_blocks_blockid_linear2.lora_up.weight\": \"single_blocks.blockid.proj_out.lora_B.default.weight\",\n        }\n\n    def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0):\n        super().load(model, state_dict_lora, alpha)\n\n    \n    def convert_state_dict(self,state_dict):\n\n        def guess_block_id(name,model_resource):\n            if model_resource == 'civitai':\n                names = 
name.split(\"_\")\n                for i in names:\n                    if i.isdigit():\n                        return i, name.replace(f\"_{i}_\", \"_blockid_\")\n            if model_resource == 'diffusers':\n                names = name.split(\".\")\n                for i in names:\n                    if i.isdigit():\n                        return i, name.replace(f\"transformer_blocks.{i}.\", \"transformer_blocks.blockid.\")\n            return None, None\n\n        def guess_resource(state_dict):\n            for k in state_dict:\n                if \"lora_unet_\" in k:\n                    return 'civitai'\n                elif k.startswith(\"transformer.\"):\n                    return 'diffusers'\n                else:\n                    None\n        \n        model_resource = guess_resource(state_dict)\n        if model_resource is None:\n            return state_dict\n\n        rename_dict = self.diffusers_rename_dict if model_resource == 'diffusers' else self.civitai_rename_dict\n        def guess_alpha(state_dict):\n                for name, param in state_dict.items():\n                    if \".alpha\" in name:\n                        for suffix in [\".lora_down.weight\", \".lora_A.weight\"]:\n                            name_ = name.replace(\".alpha\", suffix)\n                            if name_ in state_dict:\n                                lora_alpha = param.item() / state_dict[name_].shape[0]\n                                lora_alpha = math.sqrt(lora_alpha)\n                                return lora_alpha\n\n                return 1\n        \n        alpha = guess_alpha(state_dict)\n        \n        state_dict_ = {}\n        for name, param in state_dict.items():\n            block_id, source_name = guess_block_id(name,model_resource)\n            if alpha != 1:\n                param *= alpha\n            if source_name in rename_dict:\n                target_name = rename_dict[source_name]\n                target_name = 
target_name.replace(\".blockid.\", f\".{block_id}.\")\n                state_dict_[target_name] = param\n            else:\n                state_dict_[name] = param\n        \n        if model_resource == 'diffusers':\n            for name in list(state_dict_.keys()):\n                if \"single_blocks.\" in name and \".a_to_q.\" in name:\n                    mlp = state_dict_.get(name.replace(\".a_to_q.\", \".proj_in_besides_attn.\"), None)\n                    if mlp is None:\n                        dim = 4\n                        if 'lora_A' in name:\n                            dim = 1\n                        mlp = torch.zeros(dim * state_dict_[name].shape[0],\n                                        *state_dict_[name].shape[1:],\n                                        dtype=state_dict_[name].dtype)\n                    else:\n                        state_dict_.pop(name.replace(\".a_to_q.\", \".proj_in_besides_attn.\"))\n                    if 'lora_A' in name:\n                        param = torch.concat([\n                            state_dict_.pop(name),\n                            state_dict_.pop(name.replace(\".a_to_q.\", \".a_to_k.\")),\n                            state_dict_.pop(name.replace(\".a_to_q.\", \".a_to_v.\")),\n                            mlp,\n                        ], dim=0)\n                    elif 'lora_B' in name:\n                        d, r = state_dict_[name].shape\n                        param = torch.zeros((3*d+mlp.shape[0], 3*r+mlp.shape[1]), dtype=state_dict_[name].dtype, device=state_dict_[name].device)\n                        param[:d, :r] = state_dict_.pop(name)\n                        param[d:2*d, r:2*r] = state_dict_.pop(name.replace(\".a_to_q.\", \".a_to_k.\"))\n                        param[2*d:3*d, 2*r:3*r] = state_dict_.pop(name.replace(\".a_to_q.\", \".a_to_v.\"))\n                        param[3*d:, 3*r:] = mlp\n                    else:\n                        param = torch.concat([\n                   
         state_dict_.pop(name),\n                            state_dict_.pop(name.replace(\".a_to_q.\", \".a_to_k.\")),\n                            state_dict_.pop(name.replace(\".a_to_q.\", \".a_to_v.\")),\n                            mlp,\n                        ], dim=0)\n                    name_ = name.replace(\".a_to_q.\", \".to_qkv_mlp.\")\n                    state_dict_[name_] = param\n            for name in list(state_dict_.keys()):\n                for component in [\"a\", \"b\"]:\n                    if f\".{component}_to_q.\" in name:\n                        name_ = name.replace(f\".{component}_to_q.\", f\".{component}_to_qkv.\")\n                        concat_dim = 0\n                        if 'lora_A' in name:\n                            param = torch.concat([\n                                state_dict_[name.replace(f\".{component}_to_q.\", f\".{component}_to_q.\")],\n                                state_dict_[name.replace(f\".{component}_to_q.\", f\".{component}_to_k.\")],\n                                state_dict_[name.replace(f\".{component}_to_q.\", f\".{component}_to_v.\")],\n                            ], dim=0)\n                        elif 'lora_B' in name:\n                            origin = state_dict_[name.replace(f\".{component}_to_q.\", f\".{component}_to_q.\")]\n                            d, r = origin.shape\n                            # print(d, r)\n                            param = torch.zeros((3*d, 3*r), dtype=origin.dtype, device=origin.device)\n                            param[:d, :r] = state_dict_[name.replace(f\".{component}_to_q.\", f\".{component}_to_q.\")]\n                            param[d:2*d, r:2*r] = state_dict_[name.replace(f\".{component}_to_q.\", f\".{component}_to_k.\")]\n                            param[2*d:3*d, 2*r:3*r] = state_dict_[name.replace(f\".{component}_to_q.\", f\".{component}_to_v.\")]\n                        else:\n                            param = torch.concat([\n                  
              state_dict_[name.replace(f\".{component}_to_q.\", f\".{component}_to_q.\")],\n                                state_dict_[name.replace(f\".{component}_to_q.\", f\".{component}_to_k.\")],\n                                state_dict_[name.replace(f\".{component}_to_q.\", f\".{component}_to_v.\")],\n                            ], dim=0)\n                        state_dict_[name_] = param\n                        state_dict_.pop(name.replace(f\".{component}_to_q.\", f\".{component}_to_q.\"))\n                        state_dict_.pop(name.replace(f\".{component}_to_q.\", f\".{component}_to_k.\"))\n                        state_dict_.pop(name.replace(f\".{component}_to_q.\", f\".{component}_to_v.\"))  \n        return state_dict_\n\n\nclass LoraMerger(torch.nn.Module):\n    def __init__(self, dim):\n        super().__init__()\n        self.weight_base = torch.nn.Parameter(torch.randn((dim,)))\n        self.weight_lora = torch.nn.Parameter(torch.randn((dim,)))\n        self.weight_cross = torch.nn.Parameter(torch.randn((dim,)))\n        self.weight_out = torch.nn.Parameter(torch.ones((dim,)))\n        self.bias = torch.nn.Parameter(torch.randn((dim,)))\n        self.activation = torch.nn.Sigmoid()\n        self.norm_base = torch.nn.LayerNorm(dim, eps=1e-5)\n        self.norm_lora = torch.nn.LayerNorm(dim, eps=1e-5)\n        \n    def forward(self, base_output, lora_outputs):\n        norm_base_output = self.norm_base(base_output)\n        norm_lora_outputs = self.norm_lora(lora_outputs)\n        gate = self.activation(\n            norm_base_output * self.weight_base \\\n            + norm_lora_outputs * self.weight_lora \\\n            + norm_base_output * norm_lora_outputs * self.weight_cross + self.bias\n        )\n        output = base_output + (self.weight_out * gate * lora_outputs).sum(dim=0)\n        return output\n\nclass FluxLoraPatcher(torch.nn.Module):\n    def __init__(self, lora_patterns=None):\n        super().__init__()\n        if 
lora_patterns is None:\n            lora_patterns = self.default_lora_patterns()\n        model_dict = {}\n        for lora_pattern in lora_patterns:\n            name, dim = lora_pattern[\"name\"], lora_pattern[\"dim\"]\n            model_dict[name.replace(\".\", \"___\")] = LoraMerger(dim)\n        self.model_dict = torch.nn.ModuleDict(model_dict)\n        \n    def default_lora_patterns(self):\n        lora_patterns = []\n        lora_dict = {\n            \"attn.a_to_qkv\": 9216, \"attn.a_to_out\": 3072, \"ff_a.0\": 12288, \"ff_a.2\": 3072, \"norm1_a.linear\": 18432,\n            \"attn.b_to_qkv\": 9216, \"attn.b_to_out\": 3072, \"ff_b.0\": 12288, \"ff_b.2\": 3072, \"norm1_b.linear\": 18432,\n        }\n        for i in range(19):\n            for suffix in lora_dict:\n                lora_patterns.append({\n                    \"name\": f\"blocks.{i}.{suffix}\",\n                    \"dim\": lora_dict[suffix]\n                })\n        lora_dict = {\"to_qkv_mlp\": 21504, \"proj_out\": 3072, \"norm.linear\": 9216}\n        for i in range(38):\n            for suffix in lora_dict:\n                lora_patterns.append({\n                    \"name\": f\"single_blocks.{i}.{suffix}\",\n                    \"dim\": lora_dict[suffix]\n                })\n        return lora_patterns\n        \n    def forward(self, base_output, lora_outputs, name):\n        return self.model_dict[name.replace(\".\", \"___\")](base_output, lora_outputs)\n"
  },
  {
    "path": "diffsynth/models/flux_text_encoder_clip.py",
    "content": "import torch\n\n\nclass Attention(torch.nn.Module):\n\n    def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):\n        super().__init__()\n        dim_inner = head_dim * num_heads\n        kv_dim = kv_dim if kv_dim is not None else q_dim\n        self.num_heads = num_heads\n        self.head_dim = head_dim\n\n        self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q)\n        self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)\n        self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)\n        self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out)\n\n    def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):\n        if encoder_hidden_states is None:\n            encoder_hidden_states = hidden_states\n\n        batch_size = encoder_hidden_states.shape[0]\n\n        q = self.to_q(hidden_states)\n        k = self.to_k(encoder_hidden_states)\n        v = self.to_v(encoder_hidden_states)\n\n        q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)\n        k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)\n        v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)\n\n        hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)\n        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)\n        hidden_states = hidden_states.to(q.dtype)\n\n        hidden_states = self.to_out(hidden_states)\n\n        return hidden_states\n\n\nclass CLIPEncoderLayer(torch.nn.Module):\n    def __init__(self, embed_dim, intermediate_size, num_heads=12, head_dim=64, use_quick_gelu=True):\n        super().__init__()\n        self.attn = Attention(q_dim=embed_dim, num_heads=num_heads, head_dim=head_dim, bias_q=True, bias_kv=True, bias_out=True)\n        self.layer_norm1 = torch.nn.LayerNorm(embed_dim)\n  
      self.layer_norm2 = torch.nn.LayerNorm(embed_dim)\n        self.fc1 = torch.nn.Linear(embed_dim, intermediate_size)\n        self.fc2 = torch.nn.Linear(intermediate_size, embed_dim)\n\n        self.use_quick_gelu = use_quick_gelu\n\n    def quickGELU(self, x):\n        return x * torch.sigmoid(1.702 * x)\n    \n    def forward(self, hidden_states, attn_mask=None):\n        residual = hidden_states\n\n        hidden_states = self.layer_norm1(hidden_states)\n        hidden_states = self.attn(hidden_states, attn_mask=attn_mask)\n        hidden_states = residual + hidden_states\n\n        residual = hidden_states\n        hidden_states = self.layer_norm2(hidden_states)\n        hidden_states = self.fc1(hidden_states)\n        if self.use_quick_gelu:\n            hidden_states = self.quickGELU(hidden_states)\n        else:\n            hidden_states = torch.nn.functional.gelu(hidden_states)\n        hidden_states = self.fc2(hidden_states)\n        hidden_states = residual + hidden_states\n\n        return hidden_states\n    \n\nclass FluxTextEncoderClip(torch.nn.Module):\n    def __init__(self, embed_dim=768, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=12, encoder_intermediate_size=3072):\n        super().__init__()\n\n        # token_embedding\n        self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim)\n\n        # position_embeds (This is a fixed tensor)\n        self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim))\n\n        # encoders\n        self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size) for _ in range(num_encoder_layers)])\n\n        # attn_mask\n        self.attn_mask = self.attention_mask(max_position_embeddings)\n\n        # final_layer_norm\n        self.final_layer_norm = torch.nn.LayerNorm(embed_dim)\n\n    def attention_mask(self, length):\n        mask = torch.empty(length, length)\n        mask.fill_(float(\"-inf\"))\n        
mask.triu_(1)\n        return mask\n\n    def forward(self, input_ids, clip_skip=2, extra_mask=None):\n        embeds = self.token_embedding(input_ids)\n        embeds = embeds + self.position_embeds.to(dtype=embeds.dtype, device=input_ids.device)\n        attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)\n        if extra_mask is not None:\n            attn_mask[:, extra_mask[0]==0] = float(\"-inf\")\n        for encoder_id, encoder in enumerate(self.encoders):\n            embeds = encoder(embeds, attn_mask=attn_mask)\n            if encoder_id + clip_skip == len(self.encoders):\n                hidden_states = embeds\n        embeds = self.final_layer_norm(embeds)\n        pooled_embeds = embeds[torch.arange(embeds.shape[0]), input_ids.to(dtype=torch.int).argmax(dim=-1)]\n        return pooled_embeds, hidden_states\n"
  },
  {
    "path": "diffsynth/models/flux_text_encoder_t5.py",
    "content": "import torch\nfrom transformers import T5EncoderModel, T5Config\n\n\nclass FluxTextEncoderT5(T5EncoderModel):\n    def __init__(self):\n        config = T5Config(**{\n            \"architectures\": [\n                \"T5EncoderModel\"\n            ],\n            \"classifier_dropout\": 0.0,\n            \"d_ff\": 10240,\n            \"d_kv\": 64,\n            \"d_model\": 4096,\n            \"decoder_start_token_id\": 0,\n            \"dense_act_fn\": \"gelu_new\",\n            \"dropout_rate\": 0.1,\n            \"dtype\": \"bfloat16\",\n            \"eos_token_id\": 1,\n            \"feed_forward_proj\": \"gated-gelu\",\n            \"initializer_factor\": 1.0,\n            \"is_encoder_decoder\": True,\n            \"is_gated_act\": True,\n            \"layer_norm_epsilon\": 1e-06,\n            \"model_type\": \"t5\",\n            \"num_decoder_layers\": 24,\n            \"num_heads\": 64,\n            \"num_layers\": 24,\n            \"output_past\": True,\n            \"pad_token_id\": 0,\n            \"relative_attention_max_distance\": 128,\n            \"relative_attention_num_buckets\": 32,\n            \"tie_word_embeddings\": False,\n            \"transformers_version\": \"4.57.1\",\n            \"use_cache\": True,\n            \"vocab_size\": 32128\n        })\n        super().__init__(config)\n\n    def forward(self, input_ids):\n        outputs = super().forward(input_ids=input_ids)\n        prompt_emb = outputs.last_hidden_state\n        return prompt_emb\n"
  },
  {
    "path": "diffsynth/models/flux_vae.py",
    "content": "import torch\nfrom einops import rearrange, repeat\n\n\nclass TileWorker:\n    def __init__(self):\n        pass\n\n\n    def mask(self, height, width, border_width):\n        # Create a mask with shape (height, width).\n        # The centre area is filled with 1, and the border line is filled with values in range (0, 1].\n        x = torch.arange(height).repeat(width, 1).T\n        y = torch.arange(width).repeat(height, 1)\n        mask = torch.stack([x + 1, height - x, y + 1, width - y]).min(dim=0).values\n        mask = (mask / border_width).clip(0, 1)\n        return mask\n\n\n    def tile(self, model_input, tile_size, tile_stride, tile_device, tile_dtype):\n        # Convert a tensor (b, c, h, w) to (b, c, tile_size, tile_size, tile_num)\n        batch_size, channel, _, _ = model_input.shape\n        model_input = model_input.to(device=tile_device, dtype=tile_dtype)\n        unfold_operator = torch.nn.Unfold(\n            kernel_size=(tile_size, tile_size),\n            stride=(tile_stride, tile_stride)\n        )\n        model_input = unfold_operator(model_input)\n        model_input = model_input.view((batch_size, channel, tile_size, tile_size, -1))\n\n        return model_input\n\n\n    def tiled_inference(self, forward_fn, model_input, tile_batch_size, inference_device, inference_dtype, tile_device, tile_dtype):\n        # Call y=forward_fn(x) for each tile\n        tile_num = model_input.shape[-1]\n        model_output_stack = []\n\n        for tile_id in range(0, tile_num, tile_batch_size):\n\n            # process input\n            tile_id_ = min(tile_id + tile_batch_size, tile_num)\n            x = model_input[:, :, :, :, tile_id: tile_id_]\n            x = x.to(device=inference_device, dtype=inference_dtype)\n            x = rearrange(x, \"b c h w n -> (n b) c h w\")\n\n            # process output\n            y = forward_fn(x)\n            y = rearrange(y, \"(n b) c h w -> b c h w n\", n=tile_id_-tile_id)\n            y = 
y.to(device=tile_device, dtype=tile_dtype)\n            model_output_stack.append(y)\n\n        model_output = torch.concat(model_output_stack, dim=-1)\n        return model_output\n\n\n    def io_scale(self, model_output, tile_size):\n        # Determine the size modification happened in forward_fn\n        # We only consider the same scale on height and width.\n        io_scale = model_output.shape[2] / tile_size\n        return io_scale\n    \n\n    def untile(self, model_output, height, width, tile_size, tile_stride, border_width, tile_device, tile_dtype):\n        # The reversed function of tile\n        mask = self.mask(tile_size, tile_size, border_width)\n        mask = mask.to(device=tile_device, dtype=tile_dtype)\n        mask = rearrange(mask, \"h w -> 1 1 h w 1\")\n        model_output = model_output * mask\n\n        fold_operator = torch.nn.Fold(\n            output_size=(height, width),\n            kernel_size=(tile_size, tile_size),\n            stride=(tile_stride, tile_stride)\n        )\n        mask = repeat(mask[0, 0, :, :, 0], \"h w -> 1 (h w) n\", n=model_output.shape[-1])\n        model_output = rearrange(model_output, \"b c h w n -> b (c h w) n\")\n        model_output = fold_operator(model_output) / fold_operator(mask)\n\n        return model_output\n\n\n    def tiled_forward(self, forward_fn, model_input, tile_size, tile_stride, tile_batch_size=1, tile_device=\"cpu\", tile_dtype=torch.float32, border_width=None):\n        # Prepare\n        inference_device, inference_dtype = model_input.device, model_input.dtype\n        height, width = model_input.shape[2], model_input.shape[3]\n        border_width = int(tile_stride*0.5) if border_width is None else border_width\n\n        # tile\n        model_input = self.tile(model_input, tile_size, tile_stride, tile_device, tile_dtype)\n\n        # inference\n        model_output = self.tiled_inference(forward_fn, model_input, tile_batch_size, inference_device, inference_dtype, tile_device, 
tile_dtype)\n\n        # resize\n        io_scale = self.io_scale(model_output, tile_size)\n        height, width = int(height*io_scale), int(width*io_scale)\n        tile_size, tile_stride = int(tile_size*io_scale), int(tile_stride*io_scale)\n        border_width = int(border_width*io_scale)\n\n        # untile\n        model_output = self.untile(model_output, height, width, tile_size, tile_stride, border_width, tile_device, tile_dtype)\n        \n        # Done!\n        model_output = model_output.to(device=inference_device, dtype=inference_dtype)\n        return model_output\n\n\nclass ConvAttention(torch.nn.Module):\n\n    def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):\n        super().__init__()\n        dim_inner = head_dim * num_heads\n        kv_dim = kv_dim if kv_dim is not None else q_dim\n        self.num_heads = num_heads\n        self.head_dim = head_dim\n\n        self.to_q = torch.nn.Conv2d(q_dim, dim_inner, kernel_size=(1, 1), bias=bias_q)\n        self.to_k = torch.nn.Conv2d(kv_dim, dim_inner, kernel_size=(1, 1), bias=bias_kv)\n        self.to_v = torch.nn.Conv2d(kv_dim, dim_inner, kernel_size=(1, 1), bias=bias_kv)\n        self.to_out = torch.nn.Conv2d(dim_inner, q_dim, kernel_size=(1, 1), bias=bias_out)\n\n    def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):\n        if encoder_hidden_states is None:\n            encoder_hidden_states = hidden_states\n\n        batch_size = encoder_hidden_states.shape[0]\n\n        conv_input = rearrange(hidden_states, \"B L C -> B C L 1\")\n        q = self.to_q(conv_input)\n        q = rearrange(q[:, :, :, 0], \"B C L -> B L C\")\n        conv_input = rearrange(encoder_hidden_states, \"B L C -> B C L 1\")\n        k = self.to_k(conv_input)\n        v = self.to_v(conv_input)\n        k = rearrange(k[:, :, :, 0], \"B C L -> B L C\")\n        v = rearrange(v[:, :, :, 0], \"B C L -> B L C\")\n\n        q = q.view(batch_size, 
-1, self.num_heads, self.head_dim).transpose(1, 2)\n        k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)\n        v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)\n\n        hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)\n        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)\n        hidden_states = hidden_states.to(q.dtype)\n\n        conv_input = rearrange(hidden_states, \"B L C -> B C L 1\")\n        hidden_states = self.to_out(conv_input)\n        hidden_states = rearrange(hidden_states[:, :, :, 0], \"B C L -> B L C\")\n\n        return hidden_states\n\n\nclass Attention(torch.nn.Module):\n\n    def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):\n        super().__init__()\n        dim_inner = head_dim * num_heads\n        kv_dim = kv_dim if kv_dim is not None else q_dim\n        self.num_heads = num_heads\n        self.head_dim = head_dim\n\n        self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q)\n        self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)\n        self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)\n        self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out)\n\n    def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):\n        if encoder_hidden_states is None:\n            encoder_hidden_states = hidden_states\n\n        batch_size = encoder_hidden_states.shape[0]\n\n        q = self.to_q(hidden_states)\n        k = self.to_k(encoder_hidden_states)\n        v = self.to_v(encoder_hidden_states)\n\n        q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)\n        k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)\n        v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)\n\n        hidden_states = 
torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)\n        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)\n        hidden_states = hidden_states.to(q.dtype)\n\n        hidden_states = self.to_out(hidden_states)\n\n        return hidden_states\n\n\nclass VAEAttentionBlock(torch.nn.Module):\n\n    def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5, use_conv_attention=True):\n        super().__init__()\n        inner_dim = num_attention_heads * attention_head_dim\n\n        self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True)\n\n        if use_conv_attention:\n            self.transformer_blocks = torch.nn.ModuleList([\n                ConvAttention(\n                    inner_dim,\n                    num_attention_heads,\n                    attention_head_dim,\n                    bias_q=True,\n                    bias_kv=True,\n                    bias_out=True\n                )\n                for d in range(num_layers)\n            ])\n        else:\n            self.transformer_blocks = torch.nn.ModuleList([\n                Attention(\n                    inner_dim,\n                    num_attention_heads,\n                    attention_head_dim,\n                    bias_q=True,\n                    bias_kv=True,\n                    bias_out=True\n                )\n                for d in range(num_layers)\n            ])\n\n    def forward(self, hidden_states, time_emb, text_emb, res_stack):\n        batch, _, height, width = hidden_states.shape\n        residual = hidden_states\n\n        hidden_states = self.norm(hidden_states)\n        inner_dim = hidden_states.shape[1]\n        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)\n\n        for block in self.transformer_blocks:\n            hidden_states = 
block(hidden_states)\n\n        hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()\n        hidden_states = hidden_states + residual\n\n        return hidden_states, time_emb, text_emb, res_stack\n\n\nclass ResnetBlock(torch.nn.Module):\n    def __init__(self, in_channels, out_channels, temb_channels=None, groups=32, eps=1e-5):\n        super().__init__()\n        self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)\n        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)\n        if temb_channels is not None:\n            self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)\n        self.norm2 = torch.nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True)\n        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)\n        self.nonlinearity = torch.nn.SiLU()\n        self.conv_shortcut = None\n        if in_channels != out_channels:\n            self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=True)\n\n    def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs):\n        x = hidden_states\n        x = self.norm1(x)\n        x = self.nonlinearity(x)\n        x = self.conv1(x)\n        if time_emb is not None:\n            emb = self.nonlinearity(time_emb)\n            emb = self.time_emb_proj(emb)[:, :, None, None]\n            x = x + emb\n        x = self.norm2(x)\n        x = self.nonlinearity(x)\n        x = self.conv2(x)\n        if self.conv_shortcut is not None:\n            hidden_states = self.conv_shortcut(hidden_states)\n        hidden_states = hidden_states + x\n        return hidden_states, time_emb, text_emb, res_stack\n\n\nclass UpSampler(torch.nn.Module):\n    def __init__(self, channels):\n        super().__init__()\n        self.conv = 
torch.nn.Conv2d(channels, channels, 3, padding=1)\n\n    def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs):\n        hidden_states = torch.nn.functional.interpolate(hidden_states, scale_factor=2.0, mode=\"nearest\")\n        hidden_states = self.conv(hidden_states)\n        return hidden_states, time_emb, text_emb, res_stack\n\n\nclass DownSampler(torch.nn.Module):\n    def __init__(self, channels, padding=1, extra_padding=False):\n        super().__init__()\n        self.conv = torch.nn.Conv2d(channels, channels, 3, stride=2, padding=padding)\n        self.extra_padding = extra_padding\n\n    def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs):\n        if self.extra_padding:\n            hidden_states = torch.nn.functional.pad(hidden_states, (0, 1, 0, 1), mode=\"constant\", value=0)\n        hidden_states = self.conv(hidden_states)\n        return hidden_states, time_emb, text_emb, res_stack\n\n\nclass FluxVAEDecoder(torch.nn.Module):\n    def __init__(self, use_conv_attention=True):\n        super().__init__()\n        self.scaling_factor = 0.3611\n        self.shift_factor = 0.1159\n        self.conv_in = torch.nn.Conv2d(16, 512, kernel_size=3, padding=1) # Different from SD 1.x\n\n        self.blocks = torch.nn.ModuleList([\n            # UNetMidBlock2D\n            ResnetBlock(512, 512, eps=1e-6),\n            VAEAttentionBlock(1, 512, 512, 1, eps=1e-6, use_conv_attention=use_conv_attention),\n            ResnetBlock(512, 512, eps=1e-6),\n            # UpDecoderBlock2D\n            ResnetBlock(512, 512, eps=1e-6),\n            ResnetBlock(512, 512, eps=1e-6),\n            ResnetBlock(512, 512, eps=1e-6),\n            UpSampler(512),\n            # UpDecoderBlock2D\n            ResnetBlock(512, 512, eps=1e-6),\n            ResnetBlock(512, 512, eps=1e-6),\n            ResnetBlock(512, 512, eps=1e-6),\n            UpSampler(512),\n            # UpDecoderBlock2D\n            ResnetBlock(512, 256, eps=1e-6),\n         
   ResnetBlock(256, 256, eps=1e-6),\n            ResnetBlock(256, 256, eps=1e-6),\n            UpSampler(256),\n            # UpDecoderBlock2D\n            ResnetBlock(256, 128, eps=1e-6),\n            ResnetBlock(128, 128, eps=1e-6),\n            ResnetBlock(128, 128, eps=1e-6),\n        ])\n\n        self.conv_norm_out = torch.nn.GroupNorm(num_channels=128, num_groups=32, eps=1e-6)\n        self.conv_act = torch.nn.SiLU()\n        self.conv_out = torch.nn.Conv2d(128, 3, kernel_size=3, padding=1)\n    \n    def tiled_forward(self, sample, tile_size=64, tile_stride=32):\n        hidden_states = TileWorker().tiled_forward(\n            lambda x: self.forward(x),\n            sample,\n            tile_size,\n            tile_stride,\n            tile_device=sample.device,\n            tile_dtype=sample.dtype\n        )\n        return hidden_states\n\n    def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):\n        # For VAE Decoder, we do not need to apply the tiler on each layer.\n        if tiled:\n            return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)\n\n        # 1. pre-process\n        hidden_states = sample / self.scaling_factor + self.shift_factor\n        hidden_states = self.conv_in(hidden_states)\n        time_emb = None\n        text_emb = None\n        res_stack = None\n\n        # 2. blocks\n        for i, block in enumerate(self.blocks):\n            hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)\n        \n        # 3. 
output\n        hidden_states = self.conv_norm_out(hidden_states)\n        hidden_states = self.conv_act(hidden_states)\n        hidden_states = self.conv_out(hidden_states)\n\n        return hidden_states\n\n\nclass FluxVAEEncoder(torch.nn.Module):\n    def __init__(self, use_conv_attention=True):\n        super().__init__()\n        self.scaling_factor = 0.3611\n        self.shift_factor = 0.1159\n        self.conv_in = torch.nn.Conv2d(3, 128, kernel_size=3, padding=1)\n\n        self.blocks = torch.nn.ModuleList([\n            # DownEncoderBlock2D\n            ResnetBlock(128, 128, eps=1e-6),\n            ResnetBlock(128, 128, eps=1e-6),\n            DownSampler(128, padding=0, extra_padding=True),\n            # DownEncoderBlock2D\n            ResnetBlock(128, 256, eps=1e-6),\n            ResnetBlock(256, 256, eps=1e-6),\n            DownSampler(256, padding=0, extra_padding=True),\n            # DownEncoderBlock2D\n            ResnetBlock(256, 512, eps=1e-6),\n            ResnetBlock(512, 512, eps=1e-6),\n            DownSampler(512, padding=0, extra_padding=True),\n            # DownEncoderBlock2D\n            ResnetBlock(512, 512, eps=1e-6),\n            ResnetBlock(512, 512, eps=1e-6),\n            # UNetMidBlock2D\n            ResnetBlock(512, 512, eps=1e-6),\n            VAEAttentionBlock(1, 512, 512, 1, eps=1e-6, use_conv_attention=use_conv_attention),\n            ResnetBlock(512, 512, eps=1e-6),\n        ])\n\n        self.conv_norm_out = torch.nn.GroupNorm(num_channels=512, num_groups=32, eps=1e-6)\n        self.conv_act = torch.nn.SiLU()\n        self.conv_out = torch.nn.Conv2d(512, 32, kernel_size=3, padding=1)\n\n    def tiled_forward(self, sample, tile_size=64, tile_stride=32):\n        hidden_states = TileWorker().tiled_forward(\n            lambda x: self.forward(x),\n            sample,\n            tile_size,\n            tile_stride,\n            tile_device=sample.device,\n            tile_dtype=sample.dtype\n        )\n        return 
hidden_states\n\n    def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):\n        # For VAE Encoder, we do not need to apply the tiler on each layer.\n        if tiled:\n            return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)\n        \n        # 1. pre-process\n        hidden_states = self.conv_in(sample)\n        time_emb = None\n        text_emb = None\n        res_stack = None\n\n        # 2. blocks\n        for i, block in enumerate(self.blocks):\n            hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)\n        \n        # 3. output\n        hidden_states = self.conv_norm_out(hidden_states)\n        hidden_states = self.conv_act(hidden_states)\n        hidden_states = self.conv_out(hidden_states)\n        hidden_states = hidden_states[:, :16]\n        hidden_states = (hidden_states - self.shift_factor) * self.scaling_factor\n\n        return hidden_states\n    \n    def encode_video(self, sample, batch_size=8):\n        B = sample.shape[0]\n        hidden_states = []\n\n        for i in range(0, sample.shape[2], batch_size):\n\n            j = min(i + batch_size, sample.shape[2])\n            sample_batch = rearrange(sample[:,:,i:j], \"B C T H W -> (B T) C H W\")\n\n            hidden_states_batch = self(sample_batch)\n            hidden_states_batch = rearrange(hidden_states_batch, \"(B T) C H W -> B C T H W\", B=B)\n\n            hidden_states.append(hidden_states_batch)\n        \n        hidden_states = torch.concat(hidden_states, dim=2)\n        return hidden_states\n"
  },
  {
    "path": "diffsynth/models/flux_value_control.py",
    "content": "import torch\nfrom .general_modules import TemporalTimesteps\n\n\nclass MultiValueEncoder(torch.nn.Module):\n    def __init__(self, encoders=()):\n        super().__init__()\n        if not isinstance(encoders, list):\n            encoders = [encoders]\n        self.encoders = torch.nn.ModuleList(encoders)\n\n    def __call__(self, values, dtype):\n        emb = []\n        for encoder, value in zip(self.encoders, values):\n            if value is not None:\n                value = value.unsqueeze(0)\n                emb.append(encoder(value, dtype))\n        emb = torch.concat(emb, dim=0)\n        return emb\n\n\nclass SingleValueEncoder(torch.nn.Module):\n    def __init__(self, dim_in=256, dim_out=4096, prefer_len=32, computation_device=None):\n        super().__init__()\n        self.prefer_len = prefer_len\n        self.prefer_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0, computation_device=computation_device)\n        self.prefer_value_embedder = torch.nn.Sequential(\n            torch.nn.Linear(dim_in, dim_out), torch.nn.SiLU(), torch.nn.Linear(dim_out, dim_out)\n        )\n        self.positional_embedding = torch.nn.Parameter(\n            torch.randn(self.prefer_len, dim_out) \n        )\n\n    def forward(self, value, dtype):\n        value = value * 1000\n        emb = self.prefer_proj(value).to(dtype)\n        emb = self.prefer_value_embedder(emb).squeeze(0)\n        base_embeddings = emb.expand(self.prefer_len, -1)\n        positional_embedding = self.positional_embedding.to(dtype=base_embeddings.dtype, device=base_embeddings.device)\n        learned_embeddings = base_embeddings + positional_embedding\n        return learned_embeddings\n\n    @staticmethod\n    def state_dict_converter():\n        return SingleValueEncoderStateDictConverter()\n\n\nclass SingleValueEncoderStateDictConverter:\n    def __init__(self):\n        pass\n\n    def from_diffusers(self, state_dict):\n        return 
state_dict\n\n    def from_civitai(self, state_dict):\n        return state_dict\n"
  },
  {
    "path": "diffsynth/models/general_modules.py",
    "content": "import torch, math\n\n\ndef get_timestep_embedding(\n    timesteps: torch.Tensor,\n    embedding_dim: int,\n    flip_sin_to_cos: bool = False,\n    downscale_freq_shift: float = 1,\n    scale: float = 1,\n    max_period: int = 10000,\n    computation_device = None,\n    align_dtype_to_timestep = False,\n):\n    assert len(timesteps.shape) == 1, \"Timesteps should be a 1d-array\"\n\n    half_dim = embedding_dim // 2\n    exponent = -math.log(max_period) * torch.arange(\n        start=0, end=half_dim, dtype=torch.float32, device=timesteps.device if computation_device is None else computation_device\n    )\n    exponent = exponent / (half_dim - downscale_freq_shift)\n\n    emb = torch.exp(exponent)\n    if align_dtype_to_timestep:\n        emb = emb.to(timesteps.dtype)\n    emb = timesteps[:, None].float() * emb[None, :]\n\n    # scale embeddings\n    emb = scale * emb\n\n    # concat sine and cosine embeddings\n    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)\n\n    # flip sine and cosine embeddings\n    if flip_sin_to_cos:\n        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)\n\n    # zero pad\n    if embedding_dim % 2 == 1:\n        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))\n    return emb\n\n\nclass TemporalTimesteps(torch.nn.Module):\n    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, computation_device = None, scale=1, align_dtype_to_timestep=False):\n        super().__init__()\n        self.num_channels = num_channels\n        self.flip_sin_to_cos = flip_sin_to_cos\n        self.downscale_freq_shift = downscale_freq_shift\n        self.computation_device = computation_device\n        self.scale = scale\n        self.align_dtype_to_timestep = align_dtype_to_timestep\n\n    def forward(self, timesteps):\n        t_emb = get_timestep_embedding(\n            timesteps,\n            self.num_channels,\n            flip_sin_to_cos=self.flip_sin_to_cos,\n            
downscale_freq_shift=self.downscale_freq_shift,\n            computation_device=self.computation_device,\n            scale=self.scale,\n            align_dtype_to_timestep=self.align_dtype_to_timestep,\n        )\n        return t_emb\n\n\nclass DiffusersCompatibleTimestepProj(torch.nn.Module):\n    def __init__(self, dim_in, dim_out):\n        super().__init__()\n        self.linear_1 = torch.nn.Linear(dim_in, dim_out)\n        self.act = torch.nn.SiLU()\n        self.linear_2 = torch.nn.Linear(dim_out, dim_out)\n\n    def forward(self, x):\n        x = self.linear_1(x)\n        x = self.act(x)\n        x = self.linear_2(x)\n        return x\n\n\nclass TimestepEmbeddings(torch.nn.Module):\n    def __init__(self, dim_in, dim_out, computation_device=None, diffusers_compatible_format=False, scale=1, align_dtype_to_timestep=False, use_additional_t_cond=False):\n        super().__init__()\n        self.time_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0, computation_device=computation_device, scale=scale, align_dtype_to_timestep=align_dtype_to_timestep)\n        if diffusers_compatible_format:\n            self.timestep_embedder = DiffusersCompatibleTimestepProj(dim_in, dim_out)\n        else:\n            self.timestep_embedder = torch.nn.Sequential(\n                torch.nn.Linear(dim_in, dim_out), torch.nn.SiLU(), torch.nn.Linear(dim_out, dim_out)\n            )\n        self.use_additional_t_cond = use_additional_t_cond\n        if use_additional_t_cond:\n            self.addition_t_embedding = torch.nn.Embedding(2, dim_out)\n\n    def forward(self, timestep, dtype, addition_t_cond=None):\n        time_emb = self.time_proj(timestep).to(dtype)\n        time_emb = self.timestep_embedder(time_emb)\n        if addition_t_cond is not None:\n            addition_t_emb = self.addition_t_embedding(addition_t_cond)\n            addition_t_emb = addition_t_emb.to(dtype=dtype)\n            time_emb = time_emb + addition_t_emb\n    
    return time_emb\n\n\nclass RMSNorm(torch.nn.Module):\n    def __init__(self, dim, eps, elementwise_affine=True):\n        super().__init__()\n        self.eps = eps\n        if elementwise_affine:\n            self.weight = torch.nn.Parameter(torch.ones((dim,)))\n        else:\n            self.weight = None\n\n    def forward(self, hidden_states):\n        input_dtype = hidden_states.dtype\n        variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)\n        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)\n        hidden_states = hidden_states.to(input_dtype)\n        if self.weight is not None:\n            hidden_states = hidden_states * self.weight\n        return hidden_states\n\n\nclass AdaLayerNorm(torch.nn.Module):\n    def __init__(self, dim, single=False, dual=False):\n        super().__init__()\n        self.single = single\n        self.dual = dual\n        self.linear = torch.nn.Linear(dim, dim * [[6, 2][single], 9][dual])\n        self.norm = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)\n\n    def forward(self, x, emb):\n        emb = self.linear(torch.nn.functional.silu(emb))\n        if self.single:\n            scale, shift = emb.unsqueeze(1).chunk(2, dim=2)\n            x = self.norm(x) * (1 + scale) + shift\n            return x\n        elif self.dual:\n            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, shift_msa2, scale_msa2, gate_msa2 = emb.unsqueeze(1).chunk(9, dim=2)\n            norm_x = self.norm(x)\n            x = norm_x * (1 + scale_msa) + shift_msa\n            norm_x2 = norm_x * (1 + scale_msa2) + shift_msa2\n            return x, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_x2, gate_msa2\n        else:\n            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.unsqueeze(1).chunk(6, dim=2)\n            x = self.norm(x) * (1 + scale_msa) + shift_msa\n            return x, gate_msa, shift_mlp, scale_mlp, gate_mlp\n"
  },
  {
    "path": "diffsynth/models/longcat_video_dit.py",
    "content": "from typing import List, Optional, Tuple\n\nimport math\nimport torch\nimport torch.nn as nn\nimport torch.amp as amp\n\nimport numpy as np\nimport torch.nn.functional as F\nfrom einops import rearrange, repeat\nfrom .wan_video_dit import flash_attention\nfrom ..core.device.npu_compatible_device import get_device_type\nfrom ..core.gradient import gradient_checkpoint_forward\n\n\nclass RMSNorm_FP32(torch.nn.Module):\n    def __init__(self, dim: int, eps: float):\n        super().__init__()\n        self.eps = eps\n        self.weight = nn.Parameter(torch.ones(dim))\n\n    def _norm(self, x):\n        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)\n\n    def forward(self, x):\n        output = self._norm(x.float()).type_as(x)\n        return output * self.weight\n\n\ndef broadcat(tensors, dim=-1):\n    num_tensors = len(tensors)\n    shape_lens = set(list(map(lambda t: len(t.shape), tensors)))\n    assert len(shape_lens) == 1, \"tensors must all have the same number of dimensions\"\n    shape_len = list(shape_lens)[0]\n    dim = (dim + shape_len) if dim < 0 else dim\n    dims = list(zip(*map(lambda t: list(t.shape), tensors)))\n    expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]\n    assert all(\n        [*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]\n    ), \"invalid dimensions for broadcastable concatentation\"\n    max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))\n    expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))\n    expanded_dims.insert(dim, (dim, dims[dim]))\n    expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))\n    tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))\n    return torch.cat(tensors, dim=dim)\n\n\ndef rotate_half(x):\n    x = rearrange(x, \"... (d r) -> ... d r\", r=2)\n    x1, x2 = x.unbind(dim=-1)\n    x = torch.stack((-x2, x1), dim=-1)\n    return rearrange(x, \"... d r -> ... 
(d r)\")\n\n\nclass RotaryPositionalEmbedding(nn.Module):\n\n    def __init__(self,\n                 head_dim,\n                 cp_split_hw=None\n                 ):\n        \"\"\"Rotary positional embedding for 3D\n        Reference : https://blog.eleuther.ai/rotary-embeddings/\n        Paper: https://arxiv.org/pdf/2104.09864.pdf\n        Args:\n            dim: Dimension of embedding\n            base: Base value for exponential\n        \"\"\"\n        super().__init__()\n        self.head_dim = head_dim\n        assert self.head_dim % 8 == 0, 'Dim must be a multiply of 8 for 3D RoPE.'\n        self.cp_split_hw = cp_split_hw\n        # We take the assumption that the longest side of grid will not larger than 512, i.e, 512 * 8 = 4098 input pixels\n        self.base = 10000\n        self.freqs_dict = {}\n\n    def register_grid_size(self, grid_size):\n        if grid_size not in self.freqs_dict:\n            self.freqs_dict.update({\n                grid_size: self.precompute_freqs_cis_3d(grid_size)\n            })\n\n    def precompute_freqs_cis_3d(self, grid_size):\n        num_frames, height, width = grid_size     \n        dim_t = self.head_dim - 4 * (self.head_dim // 6)\n        dim_h = 2 * (self.head_dim // 6)\n        dim_w = 2 * (self.head_dim // 6)\n        freqs_t = 1.0 / (self.base ** (torch.arange(0, dim_t, 2)[: (dim_t // 2)].float() / dim_t))\n        freqs_h = 1.0 / (self.base ** (torch.arange(0, dim_h, 2)[: (dim_h // 2)].float() / dim_h))\n        freqs_w = 1.0 / (self.base ** (torch.arange(0, dim_w, 2)[: (dim_w // 2)].float() / dim_w))\n        grid_t = np.linspace(0, num_frames, num_frames, endpoint=False, dtype=np.float32)\n        grid_h = np.linspace(0, height, height, endpoint=False, dtype=np.float32)\n        grid_w = np.linspace(0, width, width, endpoint=False, dtype=np.float32)\n        grid_t = torch.from_numpy(grid_t).float()\n        grid_h = torch.from_numpy(grid_h).float()\n        grid_w = torch.from_numpy(grid_w).float()\n        
freqs_t = torch.einsum(\"..., f -> ... f\", grid_t, freqs_t)\n        freqs_h = torch.einsum(\"..., f -> ... f\", grid_h, freqs_h)\n        freqs_w = torch.einsum(\"..., f -> ... f\", grid_w, freqs_w)\n        freqs_t = repeat(freqs_t, \"... n -> ... (n r)\", r=2)\n        freqs_h = repeat(freqs_h, \"... n -> ... (n r)\", r=2)\n        freqs_w = repeat(freqs_w, \"... n -> ... (n r)\", r=2)\n        freqs = broadcat((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)\n        # (T H W D)\n        freqs = rearrange(freqs, \"T H W D -> (T H W) D\")\n        # if self.cp_split_hw[0] * self.cp_split_hw[1] > 1:\n        #     with torch.no_grad():\n        #         freqs = rearrange(freqs, \"(T H W) D -> T H W D\", T=num_frames, H=height, W=width)\n        #         freqs = context_parallel_util.split_cp_2d(freqs, seq_dim_hw=(1, 2), split_hw=self.cp_split_hw)\n        #         freqs = rearrange(freqs, \"T H W D -> (T H W) D\")\n\n        return freqs\n\n    def forward(self, q, k, grid_size):\n        \"\"\"3D RoPE.\n\n        Args:\n            query: [B, head, seq, head_dim]\n            key: [B, head, seq, head_dim]\n        Returns:\n            query and key with the same shape as input.\n        \"\"\"\n\n        if grid_size not in self.freqs_dict:\n            self.register_grid_size(grid_size)\n\n        freqs_cis = self.freqs_dict[grid_size].to(q.device)\n        q_, k_ = q.float(), k.float()\n        freqs_cis = freqs_cis.float().to(q.device)\n        cos, sin = freqs_cis.cos(), freqs_cis.sin()\n        cos, sin = rearrange(cos, 'n d -> 1 1 n d'), rearrange(sin, 'n d -> 1 1 n d')\n        q_ = (q_ * cos) + (rotate_half(q_) * sin)\n        k_ = (k_ * cos) + (rotate_half(k_) * sin)\n\n        return q_.type_as(q), k_.type_as(k)\n\n\nclass Attention(nn.Module):\n    def __init__(\n        self,\n        dim: int,\n        num_heads: int,\n        enable_flashattn3: bool = False,\n        enable_flashattn2: bool = False,\n  
      enable_xformers: bool = False,\n        enable_bsa: bool = False,\n        bsa_params: dict = None,\n        cp_split_hw: Optional[List[int]] = None\n    ) -> None:\n        super().__init__()\n        assert dim % num_heads == 0, \"dim should be divisible by num_heads\"\n        self.dim = dim\n        self.num_heads = num_heads\n        self.head_dim = dim // num_heads\n        self.scale = self.head_dim**-0.5\n        self.enable_flashattn3 = enable_flashattn3\n        self.enable_flashattn2 = enable_flashattn2\n        self.enable_xformers = enable_xformers\n        self.enable_bsa = enable_bsa\n        self.bsa_params = bsa_params\n        self.cp_split_hw = cp_split_hw\n\n        self.qkv = nn.Linear(dim, dim * 3, bias=True)\n        self.q_norm = RMSNorm_FP32(self.head_dim, eps=1e-6)\n        self.k_norm = RMSNorm_FP32(self.head_dim, eps=1e-6)\n        self.proj = nn.Linear(dim, dim)\n\n        self.rope_3d = RotaryPositionalEmbedding(\n            self.head_dim,\n            cp_split_hw=cp_split_hw\n        )\n\n    def _process_attn(self, q, k, v, shape):\n        q = rearrange(q, \"B H S D -> B S (H D)\")\n        k = rearrange(k, \"B H S D -> B S (H D)\")\n        v = rearrange(v, \"B H S D -> B S (H D)\")\n        x = flash_attention(q, k, v, num_heads=self.num_heads)\n        x = rearrange(x, \"B S (H D) -> B H S D\", H=self.num_heads)\n        return x\n\n    def forward(self, x: torch.Tensor, shape=None, num_cond_latents=None, return_kv=False) -> torch.Tensor:\n        \"\"\"\n        \"\"\"\n        B, N, C = x.shape\n        qkv = self.qkv(x)\n\n        qkv_shape = (B, N, 3, self.num_heads, self.head_dim)\n        qkv = qkv.view(qkv_shape).permute((2, 0, 3, 1, 4)) # [3, B, H, N, D]\n        q, k, v = qkv.unbind(0)\n        q, k = self.q_norm(q), self.k_norm(k)\n\n        if return_kv:\n            k_cache, v_cache = k.clone(), v.clone()\n\n        q, k = self.rope_3d(q, k, shape)\n\n        # cond mode\n        if num_cond_latents is not None 
and num_cond_latents > 0:\n            num_cond_latents_thw = num_cond_latents * (N // shape[0])\n            # process the condition tokens\n            q_cond = q[:, :, :num_cond_latents_thw].contiguous()\n            k_cond = k[:, :, :num_cond_latents_thw].contiguous()\n            v_cond = v[:, :, :num_cond_latents_thw].contiguous()\n            x_cond = self._process_attn(q_cond, k_cond, v_cond, shape)\n            # process the noise tokens\n            q_noise = q[:, :, num_cond_latents_thw:].contiguous()\n            x_noise = self._process_attn(q_noise, k, v, shape)\n            # merge x_cond and x_noise\n            x = torch.cat([x_cond, x_noise], dim=2).contiguous()\n        else:\n            x = self._process_attn(q, k, v, shape)\n\n        x_output_shape = (B, N, C)\n        x = x.transpose(1, 2) # [B, H, N, D] --> [B, N, H, D]\n        x = x.reshape(x_output_shape) # [B, N, H, D] --> [B, N, C]\n        x = self.proj(x)\n\n        if return_kv:\n            return x, (k_cache, v_cache)\n        else:\n            return x\n\n    def forward_with_kv_cache(self, x: torch.Tensor, shape=None, num_cond_latents=None, kv_cache=None) -> torch.Tensor:\n        \"\"\"\n        \"\"\"\n        B, N, C = x.shape\n        qkv = self.qkv(x)\n        \n        qkv_shape = (B, N, 3, self.num_heads, self.head_dim)\n        qkv = qkv.view(qkv_shape).permute((2, 0, 3, 1, 4)) # [3, B, H, N, D]\n        q, k, v = qkv.unbind(0)\n        q, k = self.q_norm(q), self.k_norm(k)\n\n        T, H, W = shape\n        k_cache, v_cache = kv_cache\n        assert k_cache.shape[0] == v_cache.shape[0] and k_cache.shape[0] in [1, B]\n        if k_cache.shape[0] == 1:\n            k_cache = k_cache.repeat(B, 1, 1, 1)\n            v_cache = v_cache.repeat(B, 1, 1, 1)\n        \n        if num_cond_latents is not None and num_cond_latents > 0:\n            k_full = torch.cat([k_cache, k], dim=2).contiguous()\n            v_full = torch.cat([v_cache, v], dim=2).contiguous()\n            
q_padding = torch.cat([torch.empty_like(k_cache), q], dim=2).contiguous()\n            q_padding, k_full = self.rope_3d(q_padding, k_full, (T + num_cond_latents, H, W))\n            q = q_padding[:, :, -N:].contiguous()\n            \n        x = self._process_attn(q, k_full, v_full, shape)\n        \n        x_output_shape = (B, N, C)\n        x = x.transpose(1, 2) # [B, H, N, D] --> [B, N, H, D]\n        x = x.reshape(x_output_shape) # [B, N, H, D] --> [B, N, C]\n        x = self.proj(x)\n\n        return x\n\n\nclass MultiHeadCrossAttention(nn.Module):\n    def __init__(\n            self,\n            dim,\n            num_heads,\n            enable_flashattn3=False,\n            enable_flashattn2=False,\n            enable_xformers=False,\n        ):\n        super(MultiHeadCrossAttention, self).__init__()\n        assert dim % num_heads == 0, \"d_model must be divisible by num_heads\"\n\n        self.dim = dim\n        self.num_heads = num_heads\n        self.head_dim = dim // num_heads\n\n        self.q_linear = nn.Linear(dim, dim)\n        self.kv_linear = nn.Linear(dim, dim * 2)\n        self.proj = nn.Linear(dim, dim)\n\n        self.q_norm = RMSNorm_FP32(self.head_dim, eps=1e-6)\n        self.k_norm = RMSNorm_FP32(self.head_dim, eps=1e-6)\n\n        self.enable_flashattn3 = enable_flashattn3\n        self.enable_flashattn2 = enable_flashattn2\n        self.enable_xformers = enable_xformers\n\n    def _process_cross_attn(self, x, cond, kv_seqlen):\n        B, N, C = x.shape\n        assert C == self.dim and cond.shape[2] == self.dim\n\n        q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)\n        kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)\n        k, v = kv.unbind(2)\n\n        q, k = self.q_norm(q), self.k_norm(k)\n\n        q = rearrange(q, \"B S H D -> B S (H D)\")\n        k = rearrange(k, \"B S H D -> B S (H D)\")\n        v = rearrange(v, \"B S H D -> B S (H D)\")\n        x = flash_attention(q, k, v, 
num_heads=self.num_heads)\n\n        x = x.view(B, -1, C)\n        x = self.proj(x)\n        return x\n\n    def forward(self, x, cond, kv_seqlen, num_cond_latents=None, shape=None):\n        \"\"\"\n            x: [B, N, C]\n            cond: [B, M, C]\n        \"\"\"\n        if num_cond_latents is None or num_cond_latents == 0:\n            return self._process_cross_attn(x, cond, kv_seqlen)\n        else:\n            B, N, C = x.shape\n            if num_cond_latents is not None and num_cond_latents > 0:\n                assert shape is not None, \"SHOULD pass in the shape\"\n                num_cond_latents_thw = num_cond_latents * (N // shape[0])\n                x_noise = x[:, num_cond_latents_thw:] # [B, N_noise, C]\n                output_noise = self._process_cross_attn(x_noise, cond, kv_seqlen) # [B, N_noise, C]\n                output = torch.cat([\n                    torch.zeros((B, num_cond_latents_thw, C), dtype=output_noise.dtype, device=output_noise.device),\n                    output_noise\n                ], dim=1).contiguous()\n            else:\n                raise NotImplementedError\n                \n            return output\n\n\nclass LayerNorm_FP32(nn.LayerNorm):\n    def __init__(self, dim, eps, elementwise_affine):\n        super().__init__(dim, eps=eps, elementwise_affine=elementwise_affine)\n\n    def forward(self, inputs: torch.Tensor) -> torch.Tensor:\n        origin_dtype = inputs.dtype\n        out = F.layer_norm(\n            inputs.float(), \n            self.normalized_shape, \n            None if self.weight is None else self.weight.float(), \n            None if self.bias is None else self.bias.float() ,\n            self.eps\n        ).to(origin_dtype)\n        return out\n\n\ndef modulate_fp32(norm_func, x, shift, scale):\n    # Suppose x is (B, N, D), shift is (B, -1, D), scale is (B, -1, D)\n    # ensure the modulation params be fp32\n    assert shift.dtype == torch.float32, scale.dtype == torch.float32\n    dtype = 
x.dtype\n    x = norm_func(x.to(torch.float32))\n    x = x * (scale + 1) + shift\n    x = x.to(dtype)\n    return x\n\n\nclass FinalLayer_FP32(nn.Module):\n    \"\"\"\n    The final layer of DiT.\n    \"\"\"\n\n    def __init__(self, hidden_size, num_patch, out_channels, adaln_tembed_dim):\n        super().__init__()\n        self.hidden_size = hidden_size\n        self.num_patch = num_patch\n        self.out_channels = out_channels\n        self.adaln_tembed_dim = adaln_tembed_dim\n\n        self.norm_final = LayerNorm_FP32(hidden_size, elementwise_affine=False, eps=1e-6)\n        self.linear = nn.Linear(hidden_size, num_patch * out_channels, bias=True)\n        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(adaln_tembed_dim, 2 * hidden_size, bias=True))\n\n    def forward(self, x, t, latent_shape):\n        # timestep shape: [B, T, C]\n        assert t.dtype == torch.float32\n        B, N, C = x.shape\n        T, _, _ = latent_shape\n\n        with amp.autocast(get_device_type(), dtype=torch.float32):\n            shift, scale = self.adaLN_modulation(t).unsqueeze(2).chunk(2, dim=-1) # [B, T, 1, C]\n            x = modulate_fp32(self.norm_final, x.view(B, T, -1, C), shift, scale).view(B, N, C)\n            x = self.linear(x)\n        return x\n\n\nclass FeedForwardSwiGLU(nn.Module):\n    def __init__(\n        self,\n        dim: int,\n        hidden_dim: int,\n        multiple_of: int = 256,\n        ffn_dim_multiplier: Optional[float] = None,\n    ):\n        super().__init__()\n        hidden_dim = int(2 * hidden_dim / 3)\n        # custom dim factor multiplier\n        if ffn_dim_multiplier is not None:\n            hidden_dim = int(ffn_dim_multiplier * hidden_dim)\n        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)\n\n        self.dim = dim\n        self.hidden_dim = hidden_dim\n        self.w1 = nn.Linear(dim, hidden_dim, bias=False)\n        self.w2 = nn.Linear(hidden_dim, dim, bias=False)\n        self.w3 = 
nn.Linear(dim, hidden_dim, bias=False)\n\n    def forward(self, x):\n        return self.w2(F.silu(self.w1(x)) * self.w3(x))\n\n\nclass TimestepEmbedder(nn.Module):\n    \"\"\"\n    Embeds scalar timesteps into vector representations.\n    \"\"\"\n\n    def __init__(self, t_embed_dim, frequency_embedding_size=256):\n        super().__init__()\n        self.t_embed_dim = t_embed_dim\n        self.frequency_embedding_size = frequency_embedding_size\n        self.mlp = nn.Sequential(\n            nn.Linear(frequency_embedding_size, t_embed_dim, bias=True),\n            nn.SiLU(),\n            nn.Linear(t_embed_dim, t_embed_dim, bias=True),\n        )\n\n    @staticmethod\n    def timestep_embedding(t, dim, max_period=10000):\n        \"\"\"\n        Create sinusoidal timestep embeddings.\n        :param t: a 1-D Tensor of N indices, one per batch element.\n                          These may be fractional.\n        :param dim: the dimension of the output.\n        :param max_period: controls the minimum frequency of the embeddings.\n        :return: an (N, D) Tensor of positional embeddings.\n        \"\"\"\n        half = dim // 2\n        freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half)\n        freqs = freqs.to(device=t.device)\n        args = t[:, None].float() * freqs[None]\n        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)\n        if dim % 2:\n            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)\n        return embedding\n\n    def forward(self, t, dtype):\n        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)\n        if t_freq.dtype != dtype:\n            t_freq = t_freq.to(dtype)\n        t_emb = self.mlp(t_freq)\n        return t_emb\n\n\nclass CaptionEmbedder(nn.Module):\n    \"\"\"\n    Embeds class labels into vector representations.\n    \"\"\"\n\n    def __init__(self, in_channels, hidden_size):\n        
super().__init__()\n        self.in_channels = in_channels\n        self.hidden_size = hidden_size\n        self.y_proj = nn.Sequential(\n            nn.Linear(in_channels, hidden_size, bias=True),\n            nn.GELU(approximate=\"tanh\"),\n            nn.Linear(hidden_size, hidden_size, bias=True),\n        )\n\n    def forward(self, caption):\n        B, _, N, C = caption.shape\n        caption = self.y_proj(caption)\n        return caption\n\n\nclass PatchEmbed3D(nn.Module):\n    \"\"\"Video to Patch Embedding.\n\n    Args:\n        patch_size (int): Patch token size. Default: (2,4,4).\n        in_chans (int): Number of input video channels. Default: 3.\n        embed_dim (int): Number of linear projection output channels. Default: 96.\n        norm_layer (nn.Module, optional): Normalization layer. Default: None\n    \"\"\"\n\n    def __init__(\n        self,\n        patch_size=(2, 4, 4),\n        in_chans=3,\n        embed_dim=96,\n        norm_layer=None,\n        flatten=True,\n    ):\n        super().__init__()\n        self.patch_size = patch_size\n        self.flatten = flatten\n\n        self.in_chans = in_chans\n        self.embed_dim = embed_dim\n\n        self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)\n        if norm_layer is not None:\n            self.norm = norm_layer(embed_dim)\n        else:\n            self.norm = None\n\n    def forward(self, x):\n        \"\"\"Forward function.\"\"\"\n        # padding\n        _, _, D, H, W = x.size()\n        if W % self.patch_size[2] != 0:\n            x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2]))\n        if H % self.patch_size[1] != 0:\n            x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1]))\n        if D % self.patch_size[0] != 0:\n            x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0]))\n\n        B, C, T, H, W = x.shape\n        x = self.proj(x)  # (B C T H W)\n        if self.norm is not 
None:\n            D, Wh, Ww = x.size(2), x.size(3), x.size(4)\n            x = x.flatten(2).transpose(1, 2)\n            x = self.norm(x)\n            x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww)\n        if self.flatten:\n            x = x.flatten(2).transpose(1, 2)  # BCTHW -> BNC\n        return x\n\n\nclass LongCatSingleStreamBlock(nn.Module):\n    def __init__(\n        self,\n        hidden_size: int,\n        num_heads: int,\n        mlp_ratio: int,\n        adaln_tembed_dim: int,\n        enable_flashattn3: bool = False,\n        enable_flashattn2: bool = False,\n        enable_xformers: bool = False,\n        enable_bsa: bool = False,\n        bsa_params=None,\n        cp_split_hw=None\n    ):\n        super().__init__()\n\n        self.hidden_size = hidden_size\n\n        # scale and gate modulation\n        self.adaLN_modulation = nn.Sequential(\n            nn.SiLU(),\n            nn.Linear(adaln_tembed_dim, 6 * hidden_size, bias=True)\n        )\n\n        self.mod_norm_attn = LayerNorm_FP32(hidden_size, eps=1e-6, elementwise_affine=False)\n        self.mod_norm_ffn  = LayerNorm_FP32(hidden_size, eps=1e-6, elementwise_affine=False)\n        self.pre_crs_attn_norm = LayerNorm_FP32(hidden_size, eps=1e-6, elementwise_affine=True)\n\n        self.attn = Attention(\n            dim=hidden_size,\n            num_heads=num_heads,\n            enable_flashattn3=enable_flashattn3,\n            enable_flashattn2=enable_flashattn2,\n            enable_xformers=enable_xformers,\n            enable_bsa=enable_bsa,\n            bsa_params=bsa_params,\n            cp_split_hw=cp_split_hw\n        )\n        self.cross_attn = MultiHeadCrossAttention(\n            dim=hidden_size,\n            num_heads=num_heads,\n            enable_flashattn3=enable_flashattn3,\n            enable_flashattn2=enable_flashattn2,\n            enable_xformers=enable_xformers,\n        )\n        self.ffn = FeedForwardSwiGLU(dim=hidden_size, hidden_dim=int(hidden_size * 
mlp_ratio))\n\n    def forward(self, x, y, t, y_seqlen, latent_shape, num_cond_latents=None, return_kv=False, kv_cache=None, skip_crs_attn=False):\n        \"\"\"\n            x: [B, N, C]\n            y: [1, N_valid_tokens, C]\n            t: [B, T, C_t]\n            y_seqlen: [B]; type of a list\n            latent_shape: latent shape of a single item\n        \"\"\"\n        x_dtype = x.dtype\n\n        B, N, C = x.shape\n        T, _, _ = latent_shape # S != T*H*W in case of CP split on H*W.\n\n        # compute modulation params in fp32\n        with amp.autocast(device_type=get_device_type(), dtype=torch.float32):\n            shift_msa, scale_msa, gate_msa, \\\n            shift_mlp, scale_mlp, gate_mlp = \\\n                self.adaLN_modulation(t).unsqueeze(2).chunk(6, dim=-1) # [B, T, 1, C]\n\n        # self attn with modulation\n        x_m = modulate_fp32(self.mod_norm_attn, x.view(B, T, -1, C), shift_msa, scale_msa).view(B, N, C)\n\n        if kv_cache is not None:\n            kv_cache = (kv_cache[0].to(x.device), kv_cache[1].to(x.device))\n            attn_outputs = self.attn.forward_with_kv_cache(x_m, shape=latent_shape, num_cond_latents=num_cond_latents, kv_cache=kv_cache)\n        else:\n            attn_outputs = self.attn(x_m, shape=latent_shape, num_cond_latents=num_cond_latents, return_kv=return_kv)\n        \n        if return_kv:\n            x_s, kv_cache = attn_outputs\n        else:\n            x_s = attn_outputs\n\n        with amp.autocast(device_type=get_device_type(), dtype=torch.float32):\n            x = x + (gate_msa * x_s.view(B, -1, N//T, C)).view(B, -1, C) # [B, N, C]\n        x = x.to(x_dtype)\n\n        # cross attn\n        if not skip_crs_attn:\n            if kv_cache is not None:\n                num_cond_latents = None\n            x = x + self.cross_attn(self.pre_crs_attn_norm(x), y, y_seqlen, num_cond_latents=num_cond_latents, shape=latent_shape)\n\n        # ffn with modulation\n        x_m = 
modulate_fp32(self.mod_norm_ffn, x.view(B, -1, N//T, C), shift_mlp, scale_mlp).view(B, -1, C)\n        x_s = self.ffn(x_m)\n        with amp.autocast(device_type=get_device_type(), dtype=torch.float32):\n            x = x + (gate_mlp * x_s.view(B, -1, N//T, C)).view(B, -1, C) # [B, N, C]\n        x = x.to(x_dtype)\n\n        if return_kv:\n            return x, kv_cache\n        else:\n            return x\n\n\nclass LongCatVideoTransformer3DModel(torch.nn.Module):\n    def __init__(\n        self,\n        in_channels: int = 16,\n        out_channels: int = 16,\n        hidden_size: int = 4096,\n        depth: int = 48,\n        num_heads: int = 32,\n        caption_channels: int = 4096,\n        mlp_ratio: int = 4,\n        adaln_tembed_dim: int = 512,\n        frequency_embedding_size: int = 256,\n        # default params\n        patch_size: Tuple[int] = (1, 2, 2),\n        # attention config\n        enable_flashattn3: bool = False,\n        enable_flashattn2: bool = True,\n        enable_xformers: bool = False,\n        enable_bsa: bool = False,\n        bsa_params: dict = {'sparsity': 0.9375, 'chunk_3d_shape_q': [4, 4, 4], 'chunk_3d_shape_k': [4, 4, 4]},\n        cp_split_hw: Optional[List[int]] = [1, 1],\n        text_tokens_zero_pad: bool = True,\n    ) -> None:\n        super().__init__()\n\n        self.patch_size = patch_size\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.cp_split_hw = cp_split_hw\n\n        self.x_embedder = PatchEmbed3D(patch_size, in_channels, hidden_size)\n        self.t_embedder = TimestepEmbedder(t_embed_dim=adaln_tembed_dim, frequency_embedding_size=frequency_embedding_size)\n        self.y_embedder = CaptionEmbedder(\n            in_channels=caption_channels,\n            hidden_size=hidden_size,\n        )\n\n        self.blocks = nn.ModuleList(\n            [\n                LongCatSingleStreamBlock(\n                    hidden_size=hidden_size,\n                    
num_heads=num_heads,\n                    mlp_ratio=mlp_ratio,\n                    adaln_tembed_dim=adaln_tembed_dim,\n                    enable_flashattn3=enable_flashattn3,\n                    enable_flashattn2=enable_flashattn2,\n                    enable_xformers=enable_xformers,\n                    enable_bsa=enable_bsa,\n                    bsa_params=bsa_params,\n                    cp_split_hw=cp_split_hw\n                )\n                for i in range(depth)\n            ]\n        )\n\n        self.final_layer = FinalLayer_FP32(\n            hidden_size,\n            np.prod(self.patch_size),\n            out_channels,\n            adaln_tembed_dim,\n        )\n\n        self.gradient_checkpointing = False\n        self.text_tokens_zero_pad = text_tokens_zero_pad\n\n        self.lora_dict = {}\n        self.active_loras = []\n\n    def enable_loras(self, lora_key_list=[]):\n        self.disable_all_loras()\n    \n        module_loras = {}  # {module_name: [lora1, lora2, ...]}\n        model_device = next(self.parameters()).device\n        model_dtype = next(self.parameters()).dtype\n        \n        for lora_key in lora_key_list:\n            if lora_key in self.lora_dict:\n                for lora in self.lora_dict[lora_key].loras:\n                    lora.to(model_device, dtype=model_dtype, non_blocking=True)\n                    module_name = lora.lora_name.replace(\"lora___lorahyphen___\", \"\").replace(\"___lorahyphen___\", \".\")\n                    if module_name not in module_loras:\n                        module_loras[module_name] = []\n                    module_loras[module_name].append(lora)\n                self.active_loras.append(lora_key)\n    \n        for module_name, loras in module_loras.items():\n            module = self._get_module_by_name(module_name)\n            if not hasattr(module, 'org_forward'):\n                module.org_forward = module.forward\n            module.forward = 
self._create_multi_lora_forward(module, loras)\n    \n    def _create_multi_lora_forward(self, module, loras):\n        def multi_lora_forward(x, *args, **kwargs):\n            weight_dtype = x.dtype\n            org_output = module.org_forward(x, *args, **kwargs)\n            \n            total_lora_output = 0\n            for lora in loras:\n                if lora.use_lora:\n                    lx = lora.lora_down(x.to(lora.lora_down.weight.dtype))\n                    lx = lora.lora_up(lx)\n                    lora_output = lx.to(weight_dtype) * lora.multiplier * lora.alpha_scale\n                    total_lora_output += lora_output\n            \n            return org_output + total_lora_output\n        \n        return multi_lora_forward\n    \n    def _get_module_by_name(self, module_name):\n        try:\n            module = self\n            for part in module_name.split('.'):\n                module = getattr(module, part)\n            return module\n        except AttributeError as e:\n            raise ValueError(f\"Cannot find module: {module_name}, error: {e}\")\n    \n    def disable_all_loras(self):\n        for name, module in self.named_modules():\n            if hasattr(module, 'org_forward'):\n                module.forward = module.org_forward\n                delattr(module, 'org_forward')\n        \n        for lora_key, lora_network in self.lora_dict.items():\n            for lora in lora_network.loras:\n                lora.to(\"cpu\")\n        \n        self.active_loras.clear()\n\n    def enable_bsa(self,):\n        for block in self.blocks:\n            block.attn.enable_bsa = True\n    \n    def disable_bsa(self,):\n        for block in self.blocks:\n            block.attn.enable_bsa = False    \n\n    def forward(\n        self, \n        hidden_states, \n        timestep, \n        encoder_hidden_states, \n        encoder_attention_mask=None, \n        num_cond_latents=0,\n        return_kv=False, \n        kv_cache_dict={},\n       
 skip_crs_attn=False, \n        offload_kv_cache=False,\n        use_gradient_checkpointing=False,\n        use_gradient_checkpointing_offload=False,\n    ):\n\n        B, _, T, H, W = hidden_states.shape\n\n        N_t = T // self.patch_size[0]\n        N_h = H // self.patch_size[1]\n        N_w = W // self.patch_size[2]\n\n        assert self.patch_size[0]==1, \"Currently, 3D x_embedder should not compress the temporal dimension.\"\n\n        # expand the shape of timestep from [B] to [B, T]\n        if len(timestep.shape) == 1:\n            timestep = timestep.unsqueeze(1).expand(-1, N_t).clone() # [B, T]\n        timestep[:, :num_cond_latents] = 0\n\n        dtype = hidden_states.dtype\n        hidden_states = hidden_states.to(dtype)\n        timestep = timestep.to(dtype)\n        encoder_hidden_states = encoder_hidden_states.to(dtype)\n\n        hidden_states = self.x_embedder(hidden_states)  # [B, N, C]\n\n        with amp.autocast(device_type=get_device_type(), dtype=torch.float32):\n            t = self.t_embedder(timestep.float().flatten(), dtype=torch.float32).reshape(B, N_t, -1)  # [B, T, C_t]\n\n        encoder_hidden_states = self.y_embedder(encoder_hidden_states)  # [B, 1, N_token, C]\n\n        if self.text_tokens_zero_pad and encoder_attention_mask is not None:\n            encoder_hidden_states = encoder_hidden_states * encoder_attention_mask[:, None, :, None]\n            encoder_attention_mask = (encoder_attention_mask * 0 + 1).to(encoder_attention_mask.dtype)\n\n        if encoder_attention_mask is not None:\n            encoder_attention_mask = encoder_attention_mask.squeeze(1).squeeze(1)\n            encoder_hidden_states = encoder_hidden_states.squeeze(1).masked_select(encoder_attention_mask.unsqueeze(-1) != 0).view(1, -1, hidden_states.shape[-1]) # [1, N_valid_tokens, C]\n            y_seqlens = encoder_attention_mask.sum(dim=1).tolist() # [B]\n        else:\n            y_seqlens = [encoder_hidden_states.shape[2]] * 
encoder_hidden_states.shape[0]\n            encoder_hidden_states = encoder_hidden_states.squeeze(1).view(1, -1, hidden_states.shape[-1])\n\n        # if self.cp_split_hw[0] * self.cp_split_hw[1] > 1:\n        #     hidden_states = rearrange(hidden_states, \"B (T H W) C -> B T H W C\", T=N_t, H=N_h, W=N_w)\n        #     hidden_states = context_parallel_util.split_cp_2d(hidden_states, seq_dim_hw=(2, 3), split_hw=self.cp_split_hw)\n        #     hidden_states = rearrange(hidden_states, \"B T H W C -> B (T H W) C\")\n\n        # blocks\n        kv_cache_dict_ret = {}\n        for i, block in enumerate(self.blocks):\n            block_outputs = gradient_checkpoint_forward(\n                block,\n                use_gradient_checkpointing=use_gradient_checkpointing,\n                use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,\n                x=hidden_states,\n                y=encoder_hidden_states,\n                t=t,\n                y_seqlen=y_seqlens,\n                latent_shape=(N_t, N_h, N_w),\n                num_cond_latents=num_cond_latents,\n                return_kv=return_kv,\n                kv_cache=kv_cache_dict.get(i, None),\n                skip_crs_attn=skip_crs_attn,\n            )\n            \n            if return_kv:\n                hidden_states, kv_cache = block_outputs\n                if offload_kv_cache:\n                    kv_cache_dict_ret[i] = (kv_cache[0].cpu(), kv_cache[1].cpu())\n                else:\n                    kv_cache_dict_ret[i] = (kv_cache[0].contiguous(), kv_cache[1].contiguous())\n            else:\n                hidden_states = block_outputs\n\n        hidden_states = self.final_layer(hidden_states, t, (N_t, N_h, N_w))  # [B, N, C=T_p*H_p*W_p*C_out]\n\n        # if self.cp_split_hw[0] * self.cp_split_hw[1] > 1:\n        #     hidden_states = context_parallel_util.gather_cp_2d(hidden_states, shape=(N_t, N_h, N_w), split_hw=self.cp_split_hw)\n\n        hidden_states = 
self.unpatchify(hidden_states, N_t, N_h, N_w)  # [B, C_out, H, W]\n\n        # cast to float32 for better accuracy\n        hidden_states = hidden_states.to(torch.float32)\n\n        if return_kv:\n            return hidden_states, kv_cache_dict_ret\n        else:\n            return hidden_states\n    \n\n    def unpatchify(self, x, N_t, N_h, N_w):\n        \"\"\"\n        Args:\n            x (torch.Tensor): of shape [B, N, C]\n\n        Return:\n            x (torch.Tensor): of shape [B, C_out, T, H, W]\n        \"\"\"\n        T_p, H_p, W_p = self.patch_size\n        x = rearrange(\n            x,\n            \"B (N_t N_h N_w) (T_p H_p W_p C_out) -> B C_out (N_t T_p) (N_h H_p) (N_w W_p)\",\n            N_t=N_t,\n            N_h=N_h,\n            N_w=N_w,\n            T_p=T_p,\n            H_p=H_p,\n            W_p=W_p,\n            C_out=self.out_channels,\n        )\n        return x\n\n    @staticmethod\n    def state_dict_converter():\n        return LongCatVideoTransformer3DModelDictConverter()\n\n\nclass LongCatVideoTransformer3DModelDictConverter:\n    def __init__(self):\n        pass\n\n    def from_diffusers(self, state_dict):\n        return state_dict\n    \n    def from_civitai(self, state_dict):\n        return state_dict\n\n"
  },
  {
    "path": "diffsynth/models/ltx2_audio_vae.py",
    "content": "from typing import Set, Tuple, Optional, List\nfrom enum import Enum\nimport math\nimport einops\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torchaudio\nfrom .ltx2_common import VideoLatentShape, AudioLatentShape, Patchifier, NormType, build_normalization_layer\n\n\nclass AudioProcessor(nn.Module):\n    \"\"\"Converts audio waveforms to log-mel spectrograms with optional resampling.\"\"\"\n\n    def __init__(\n        self,\n        sample_rate: int = 16000,\n        mel_bins: int = 64,\n        mel_hop_length: int = 160,\n        n_fft: int = 1024,\n    ) -> None:\n        super().__init__()\n        self.sample_rate = sample_rate\n        self.mel_transform = torchaudio.transforms.MelSpectrogram(\n            sample_rate=sample_rate,\n            n_fft=n_fft,\n            win_length=n_fft,\n            hop_length=mel_hop_length,\n            f_min=0.0,\n            f_max=sample_rate / 2.0,\n            n_mels=mel_bins,\n            window_fn=torch.hann_window,\n            center=True,\n            pad_mode=\"reflect\",\n            power=1.0,\n            mel_scale=\"slaney\",\n            norm=\"slaney\",\n        )\n\n    def resample_waveform(\n        self,\n        waveform: torch.Tensor,\n        source_rate: int,\n        target_rate: int,\n    ) -> torch.Tensor:\n        \"\"\"Resample waveform to target sample rate if needed.\"\"\"\n        if source_rate == target_rate:\n            return waveform\n        resampled = torchaudio.functional.resample(waveform, source_rate, target_rate)\n        return resampled.to(device=waveform.device, dtype=waveform.dtype)\n\n    def waveform_to_mel(\n        self,\n        waveform: torch.Tensor,\n        waveform_sample_rate: int,\n    ) -> torch.Tensor:\n        \"\"\"Convert waveform to log-mel spectrogram [batch, channels, time, n_mels].\"\"\"\n        waveform = self.resample_waveform(waveform, waveform_sample_rate, self.sample_rate)\n\n        mel = 
self.mel_transform(waveform)\n        mel = torch.log(torch.clamp(mel, min=1e-5))\n\n        mel = mel.to(device=waveform.device, dtype=waveform.dtype)\n        return mel.permute(0, 1, 3, 2).contiguous()\n\n\nclass AudioPatchifier(Patchifier):\n    def __init__(\n        self,\n        patch_size: int,\n        sample_rate: int = 16000,\n        hop_length: int = 160,\n        audio_latent_downsample_factor: int = 4,\n        is_causal: bool = True,\n        shift: int = 0,\n    ):\n        \"\"\"\n        Patchifier tailored for spectrogram/audio latents.\n        Args:\n            patch_size: Number of mel bins combined into a single patch. This\n                controls the resolution along the frequency axis.\n            sample_rate: Original waveform sampling rate. Used to map latent\n                indices back to seconds so downstream consumers can align audio\n                and video cues.\n            hop_length: Window hop length used for the spectrogram. Determines\n                how many real-time samples separate two consecutive latent frames.\n            audio_latent_downsample_factor: Ratio between spectrogram frames and\n                latent frames; compensates for additional downsampling inside the\n                VAE encoder.\n            is_causal: When True, timing is shifted to account for causal\n                receptive fields so timestamps do not peek into the future.\n            shift: Integer offset applied to the latent indices. 
Enables\n                constructing overlapping windows from the same latent sequence.\n        \"\"\"\n        self.hop_length = hop_length\n        self.sample_rate = sample_rate\n        self.audio_latent_downsample_factor = audio_latent_downsample_factor\n        self.is_causal = is_causal\n        self.shift = shift\n        self._patch_size = (1, patch_size, patch_size)\n\n    @property\n    def patch_size(self) -> Tuple[int, int, int]:\n        return self._patch_size\n\n    def get_token_count(self, tgt_shape: AudioLatentShape) -> int:\n        return tgt_shape.frames\n\n    def _get_audio_latent_time_in_sec(\n        self,\n        start_latent: int,\n        end_latent: int,\n        dtype: torch.dtype,\n        device: Optional[torch.device] = None,\n    ) -> torch.Tensor:\n        \"\"\"\n        Converts latent indices into real-time seconds while honoring causal\n        offsets and the configured hop length.\n        Args:\n            start_latent: Inclusive start index inside the latent sequence. This\n                sets the first timestamp returned.\n            end_latent: Exclusive end index. Determines how many timestamps get\n                generated.\n            dtype: Floating-point dtype used for the returned tensor, allowing\n                callers to control precision.\n            device: Target device for the timestamp tensor. 
When omitted the\n                computation occurs on CPU to avoid surprising GPU allocations.\n        \"\"\"\n        if device is None:\n            device = torch.device(\"cpu\")\n\n        audio_latent_frame = torch.arange(start_latent, end_latent, dtype=dtype, device=device)\n\n        audio_mel_frame = audio_latent_frame * self.audio_latent_downsample_factor\n\n        if self.is_causal:\n            # Frame offset for causal alignment.\n            # The \"+1\" ensures the timestamp corresponds to the first sample that is fully available.\n            causal_offset = 1\n            audio_mel_frame = (audio_mel_frame + causal_offset - self.audio_latent_downsample_factor).clip(min=0)\n\n        return audio_mel_frame * self.hop_length / self.sample_rate\n\n    def _compute_audio_timings(\n        self,\n        batch_size: int,\n        num_steps: int,\n        device: Optional[torch.device] = None,\n    ) -> torch.Tensor:\n        \"\"\"\n        Builds a `(B, 1, T, 2)` tensor containing timestamps for each latent frame.\n        This helper method underpins `get_patch_grid_bounds` for the audio patchifier.\n        Args:\n            batch_size: Number of sequences to broadcast the timings over.\n            num_steps: Number of latent frames (time steps) to convert into timestamps.\n            device: Device on which the resulting tensor should reside.\n        \"\"\"\n        resolved_device = device\n        if resolved_device is None:\n            resolved_device = torch.device(\"cpu\")\n\n        start_timings = self._get_audio_latent_time_in_sec(\n            self.shift,\n            num_steps + self.shift,\n            torch.float32,\n            resolved_device,\n        )\n        start_timings = start_timings.unsqueeze(0).expand(batch_size, -1).unsqueeze(1)\n\n        end_timings = self._get_audio_latent_time_in_sec(\n            self.shift + 1,\n            num_steps + self.shift + 1,\n            torch.float32,\n            resolved_device,\n 
       )\n        end_timings = end_timings.unsqueeze(0).expand(batch_size, -1).unsqueeze(1)\n\n        return torch.stack([start_timings, end_timings], dim=-1)\n\n    def patchify(\n        self,\n        audio_latents: torch.Tensor,\n    ) -> torch.Tensor:\n        \"\"\"\n        Flattens the audio latent tensor along time. Use `get_patch_grid_bounds`\n        to derive timestamps for each latent frame based on the configured hop\n        length and downsampling.\n        Args:\n            audio_latents: Latent tensor to patchify.\n        Returns:\n            Flattened patch tokens tensor. Use `get_patch_grid_bounds` to compute the\n            corresponding timing metadata when needed.\n        \"\"\"\n        audio_latents = einops.rearrange(\n            audio_latents,\n            \"b c t f -> b t (c f)\",\n        )\n\n        return audio_latents\n\n    def unpatchify(\n        self,\n        audio_latents: torch.Tensor,\n        output_shape: AudioLatentShape,\n    ) -> torch.Tensor:\n        \"\"\"\n        Restores the `(B, C, T, F)` spectrogram tensor from flattened patches.\n        Use `get_patch_grid_bounds` to recompute the timestamps that describe each\n        frame's position in real time.\n        Args:\n            audio_latents: Latent tensor to unpatchify.\n            output_shape: Shape of the unpatched output tensor.\n        Returns:\n            Unpatched latent tensor. 
Use `get_patch_grid_bounds` to compute the timing\n            metadata associated with the restored latents.\n        \"\"\"\n        # audio_latents shape: (batch, time, freq * channels)\n        audio_latents = einops.rearrange(\n            audio_latents,\n            \"b t (c f) -> b c t f\",\n            c=output_shape.channels,\n            f=output_shape.mel_bins,\n        )\n\n        return audio_latents\n\n    def unpatchify_audio(\n        self,\n        audio_latents: torch.Tensor,\n        channels: int,\n        mel_bins: int\n    ) -> torch.Tensor:\n        audio_latents = einops.rearrange(\n            audio_latents,\n            \"b t (c f) -> b c t f\",\n            c=channels,\n            f=mel_bins,\n        )\n        return audio_latents\n\n    def get_patch_grid_bounds(\n        self,\n        output_shape: AudioLatentShape | VideoLatentShape,\n        device: Optional[torch.device] = None,\n    ) -> torch.Tensor:\n        \"\"\"\n        Return the temporal bounds `[inclusive start, exclusive end)` for every\n        patch emitted by `patchify`. 
For audio this corresponds to timestamps in\n        seconds aligned with the original spectrogram grid.\n        The returned tensor has shape `[batch_size, 1, time_steps, 2]`, where:\n            - axis 1 (size 1) represents the temporal dimension\n            - axis 3 (size 2) stores the `[start, end)` timestamps per patch\n        Args:\n            output_shape: Audio grid specification describing the number of time steps.\n            device: Target device for the returned tensor.\n        \"\"\"\n        if not isinstance(output_shape, AudioLatentShape):\n            raise ValueError(\"AudioPatchifier expects AudioLatentShape when computing coordinates\")\n\n        return self._compute_audio_timings(output_shape.batch, output_shape.frames, device)\n\n\nclass AttentionType(Enum):\n    \"\"\"Enum for specifying the attention mechanism type.\"\"\"\n\n    VANILLA = \"vanilla\"\n    LINEAR = \"linear\"\n    NONE = \"none\"\n\n\nclass AttnBlock(torch.nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        norm_type: NormType = NormType.GROUP,\n    ) -> None:\n        super().__init__()\n        self.in_channels = in_channels\n\n        self.norm = build_normalization_layer(in_channels, normtype=norm_type)\n        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)\n        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)\n        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)\n        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        h_ = x\n        h_ = self.norm(h_)\n        q = self.q(h_)\n        k = self.k(h_)\n        v = self.v(h_)\n\n        # compute attention\n        b, c, h, w = q.shape\n        q = q.reshape(b, c, h * w).contiguous()\n        q = q.permute(0, 2, 1).contiguous()  # b,hw,c\n        k = 
k.reshape(b, c, h * w).contiguous()  # b,c,hw\n        w_ = torch.bmm(q, k).contiguous()  # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]\n        w_ = w_ * (int(c) ** (-0.5))\n        w_ = torch.nn.functional.softmax(w_, dim=2)\n\n        # attend to values\n        v = v.reshape(b, c, h * w).contiguous()\n        w_ = w_.permute(0, 2, 1).contiguous()  # b,hw,hw (first hw of k, second of q)\n        h_ = torch.bmm(v, w_).contiguous()  # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]\n        h_ = h_.reshape(b, c, h, w).contiguous()\n\n        h_ = self.proj_out(h_)\n\n        return x + h_\n\n\ndef make_attn(\n    in_channels: int,\n    attn_type: AttentionType = AttentionType.VANILLA,\n    norm_type: NormType = NormType.GROUP,\n) -> torch.nn.Module:\n    match attn_type:\n        case AttentionType.VANILLA:\n            return AttnBlock(in_channels, norm_type=norm_type)\n        case AttentionType.NONE:\n            return torch.nn.Identity()\n        case AttentionType.LINEAR:\n            raise NotImplementedError(f\"Attention type {attn_type.value} is not supported yet.\")\n        case _:\n            raise ValueError(f\"Unknown attention type: {attn_type}\")\n\n\nclass CausalityAxis(Enum):\n    \"\"\"Enum for specifying the causality axis in causal convolutions.\"\"\"\n\n    NONE = None\n    WIDTH = \"width\"\n    HEIGHT = \"height\"\n    WIDTH_COMPATIBILITY = \"width-compatibility\"\n\n\nclass CausalConv2d(torch.nn.Module):\n    \"\"\"\n    A causal 2D convolution.\n    This layer ensures that the output at time `t` only depends on inputs\n    at time `t` and earlier. 
It achieves this by applying asymmetric padding\n    to the time dimension (width) before the convolution.\n    \"\"\"\n\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        kernel_size: int | tuple[int, int],\n        stride: int = 1,\n        dilation: int | tuple[int, int] = 1,\n        groups: int = 1,\n        bias: bool = True,\n        causality_axis: CausalityAxis = CausalityAxis.HEIGHT,\n    ) -> None:\n        super().__init__()\n\n        self.causality_axis = causality_axis\n\n        # Ensure kernel_size and dilation are tuples\n        kernel_size = torch.nn.modules.utils._pair(kernel_size)\n        dilation = torch.nn.modules.utils._pair(dilation)\n\n        # Calculate padding dimensions\n        pad_h = (kernel_size[0] - 1) * dilation[0]\n        pad_w = (kernel_size[1] - 1) * dilation[1]\n\n        # The padding tuple for F.pad is (pad_left, pad_right, pad_top, pad_bottom)\n        match self.causality_axis:\n            case CausalityAxis.NONE:\n                self.padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)\n            case CausalityAxis.WIDTH | CausalityAxis.WIDTH_COMPATIBILITY:\n                self.padding = (pad_w, 0, pad_h // 2, pad_h - pad_h // 2)\n            case CausalityAxis.HEIGHT:\n                self.padding = (pad_w // 2, pad_w - pad_w // 2, pad_h, 0)\n            case _:\n                raise ValueError(f\"Invalid causality_axis: {causality_axis}\")\n\n        # The internal convolution layer uses no padding, as we handle it manually\n        self.conv = torch.nn.Conv2d(\n            in_channels,\n            out_channels,\n            kernel_size,\n            stride=stride,\n            padding=0,\n            dilation=dilation,\n            groups=groups,\n            bias=bias,\n        )\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        # Apply causal padding before convolution\n        x = F.pad(x, self.padding)\n        return 
self.conv(x)\n\n\ndef make_conv2d(\n    in_channels: int,\n    out_channels: int,\n    kernel_size: int | tuple[int, int],\n    stride: int = 1,\n    padding: tuple[int, int, int, int] | None = None,\n    dilation: int = 1,\n    groups: int = 1,\n    bias: bool = True,\n    causality_axis: CausalityAxis | None = None,\n) -> torch.nn.Module:\n    \"\"\"\n    Create a 2D convolution layer that can be either causal or non-causal.\n    Args:\n        in_channels: Number of input channels\n        out_channels: Number of output channels\n        kernel_size: Size of the convolution kernel\n        stride: Convolution stride\n        padding: Padding (if None, will be calculated based on causal flag)\n        dilation: Dilation rate\n        groups: Number of groups for grouped convolution\n        bias: Whether to use bias\n        causality_axis: Dimension along which to apply causality.\n    Returns:\n        Either a regular Conv2d or CausalConv2d layer\n    \"\"\"\n    if causality_axis is not None:\n        # For causal convolution, padding is handled internally by CausalConv2d\n        return CausalConv2d(in_channels, out_channels, kernel_size, stride, dilation, groups, bias, causality_axis)\n    else:\n        # For non-causal convolution, use symmetric padding if not specified\n        if padding is None:\n            padding = kernel_size // 2 if isinstance(kernel_size, int) else tuple(k // 2 for k in kernel_size)\n\n        return torch.nn.Conv2d(\n            in_channels,\n            out_channels,\n            kernel_size,\n            stride,\n            padding,\n            dilation,\n            groups,\n            bias,\n        )\n\n\n\nLRELU_SLOPE = 0.1\n\n\nclass ResBlock1(torch.nn.Module):\n    def __init__(self, channels: int, kernel_size: int = 3, dilation: Tuple[int, int, int] = (1, 3, 5)):\n        super(ResBlock1, self).__init__()\n        self.convs1 = torch.nn.ModuleList(\n            [\n                torch.nn.Conv1d(\n                    
channels,\n                    channels,\n                    kernel_size,\n                    1,\n                    dilation=dilation[0],\n                    padding=\"same\",\n                ),\n                torch.nn.Conv1d(\n                    channels,\n                    channels,\n                    kernel_size,\n                    1,\n                    dilation=dilation[1],\n                    padding=\"same\",\n                ),\n                torch.nn.Conv1d(\n                    channels,\n                    channels,\n                    kernel_size,\n                    1,\n                    dilation=dilation[2],\n                    padding=\"same\",\n                ),\n            ]\n        )\n\n        self.convs2 = torch.nn.ModuleList(\n            [\n                torch.nn.Conv1d(\n                    channels,\n                    channels,\n                    kernel_size,\n                    1,\n                    dilation=1,\n                    padding=\"same\",\n                ),\n                torch.nn.Conv1d(\n                    channels,\n                    channels,\n                    kernel_size,\n                    1,\n                    dilation=1,\n                    padding=\"same\",\n                ),\n                torch.nn.Conv1d(\n                    channels,\n                    channels,\n                    kernel_size,\n                    1,\n                    dilation=1,\n                    padding=\"same\",\n                ),\n            ]\n        )\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        for conv1, conv2 in zip(self.convs1, self.convs2, strict=True):\n            xt = torch.nn.functional.leaky_relu(x, LRELU_SLOPE)\n            xt = conv1(xt)\n            xt = torch.nn.functional.leaky_relu(xt, LRELU_SLOPE)\n            xt = conv2(xt)\n            x = xt + x\n        return x\n\n\nclass ResBlock2(torch.nn.Module):\n    def __init__(self, channels: 
int, kernel_size: int = 3, dilation: Tuple[int, int] = (1, 3)):\n        super(ResBlock2, self).__init__()\n        self.convs = torch.nn.ModuleList(\n            [\n                torch.nn.Conv1d(\n                    channels,\n                    channels,\n                    kernel_size,\n                    1,\n                    dilation=dilation[0],\n                    padding=\"same\",\n                ),\n                torch.nn.Conv1d(\n                    channels,\n                    channels,\n                    kernel_size,\n                    1,\n                    dilation=dilation[1],\n                    padding=\"same\",\n                ),\n            ]\n        )\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        for conv in self.convs:\n            xt = torch.nn.functional.leaky_relu(x, LRELU_SLOPE)\n            xt = conv(xt)\n            x = xt + x\n        return x\n\n\nclass ResnetBlock(torch.nn.Module):\n    def __init__(\n        self,\n        *,\n        in_channels: int,\n        out_channels: int | None = None,\n        conv_shortcut: bool = False,\n        dropout: float = 0.0,\n        temb_channels: int = 512,\n        norm_type: NormType = NormType.GROUP,\n        causality_axis: CausalityAxis = CausalityAxis.HEIGHT,\n    ) -> None:\n        super().__init__()\n        self.causality_axis = causality_axis\n\n        if self.causality_axis != CausalityAxis.NONE and norm_type == NormType.GROUP:\n            raise ValueError(\"Causal ResnetBlock with GroupNorm is not supported.\")\n        self.in_channels = in_channels\n        out_channels = in_channels if out_channels is None else out_channels\n        self.out_channels = out_channels\n        self.use_conv_shortcut = conv_shortcut\n\n        self.norm1 = build_normalization_layer(in_channels, normtype=norm_type)\n        self.non_linearity = torch.nn.SiLU()\n        self.conv1 = make_conv2d(in_channels, out_channels, kernel_size=3, stride=1, 
causality_axis=causality_axis)\n        if temb_channels > 0:\n            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)\n        self.norm2 = build_normalization_layer(out_channels, normtype=norm_type)\n        self.dropout = torch.nn.Dropout(dropout)\n        self.conv2 = make_conv2d(out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis)\n        if self.in_channels != self.out_channels:\n            if self.use_conv_shortcut:\n                self.conv_shortcut = make_conv2d(\n                    in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis\n                )\n            else:\n                self.nin_shortcut = make_conv2d(\n                    in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis\n                )\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        temb: torch.Tensor | None = None,\n    ) -> torch.Tensor:\n        h = x\n        h = self.norm1(h)\n        h = self.non_linearity(h)\n        h = self.conv1(h)\n\n        if temb is not None:\n            h = h + self.temb_proj(self.non_linearity(temb))[:, :, None, None]\n\n        h = self.norm2(h)\n        h = self.non_linearity(h)\n        h = self.dropout(h)\n        h = self.conv2(h)\n\n        if self.in_channels != self.out_channels:\n            x = self.conv_shortcut(x) if self.use_conv_shortcut else self.nin_shortcut(x)\n\n        return x + h\n\n\nclass Downsample(torch.nn.Module):\n    \"\"\"\n    A downsampling layer that can use either a strided convolution\n    or average pooling. 
Supports standard and causal padding for the\n    convolutional mode.\n    \"\"\"\n\n    def __init__(\n        self,\n        in_channels: int,\n        with_conv: bool,\n        causality_axis: CausalityAxis = CausalityAxis.WIDTH,\n    ) -> None:\n        super().__init__()\n        self.with_conv = with_conv\n        self.causality_axis = causality_axis\n\n        if self.causality_axis != CausalityAxis.NONE and not self.with_conv:\n            raise ValueError(\"causality is only supported when `with_conv=True`.\")\n\n        if self.with_conv:\n            # Do time downsampling here\n            # no asymmetric padding in torch conv, must do it ourselves\n            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        if self.with_conv:\n            # Padding tuple is in the order: (left, right, top, bottom).\n            match self.causality_axis:\n                case CausalityAxis.NONE:\n                    pad = (0, 1, 0, 1)\n                case CausalityAxis.WIDTH:\n                    pad = (2, 0, 0, 1)\n                case CausalityAxis.HEIGHT:\n                    pad = (0, 1, 2, 0)\n                case CausalityAxis.WIDTH_COMPATIBILITY:\n                    pad = (1, 0, 0, 1)\n                case _:\n                    raise ValueError(f\"Invalid causality_axis: {self.causality_axis}\")\n\n            x = torch.nn.functional.pad(x, pad, mode=\"constant\", value=0)\n            x = self.conv(x)\n        else:\n            # This branch is only taken if with_conv=False, which implies causality_axis is NONE.\n            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)\n\n        return x\n\n\ndef build_downsampling_path(  # noqa: PLR0913\n    *,\n    ch: int,\n    ch_mult: Tuple[int, ...],\n    num_resolutions: int,\n    num_res_blocks: int,\n    resolution: int,\n    temb_channels: int,\n    dropout: float,\n    norm_type: 
NormType,\n    causality_axis: CausalityAxis,\n    attn_type: AttentionType,\n    attn_resolutions: Set[int],\n    resamp_with_conv: bool,\n) -> tuple[torch.nn.ModuleList, int]:\n    \"\"\"Build the downsampling path with residual blocks, attention, and downsampling layers.\"\"\"\n    down_modules = torch.nn.ModuleList()\n    curr_res = resolution\n    in_ch_mult = (1, *tuple(ch_mult))\n    block_in = ch\n\n    for i_level in range(num_resolutions):\n        block = torch.nn.ModuleList()\n        attn = torch.nn.ModuleList()\n        block_in = ch * in_ch_mult[i_level]\n        block_out = ch * ch_mult[i_level]\n\n        for _ in range(num_res_blocks):\n            block.append(\n                ResnetBlock(\n                    in_channels=block_in,\n                    out_channels=block_out,\n                    temb_channels=temb_channels,\n                    dropout=dropout,\n                    norm_type=norm_type,\n                    causality_axis=causality_axis,\n                )\n            )\n            block_in = block_out\n            if curr_res in attn_resolutions:\n                attn.append(make_attn(block_in, attn_type=attn_type, norm_type=norm_type))\n\n        down = torch.nn.Module()\n        down.block = block\n        down.attn = attn\n        if i_level != num_resolutions - 1:\n            down.downsample = Downsample(block_in, resamp_with_conv, causality_axis=causality_axis)\n            curr_res = curr_res // 2\n        down_modules.append(down)\n\n    return down_modules, block_in\n\n\nclass Upsample(torch.nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        with_conv: bool,\n        causality_axis: CausalityAxis = CausalityAxis.HEIGHT,\n    ) -> None:\n        super().__init__()\n        self.with_conv = with_conv\n        self.causality_axis = causality_axis\n        if self.with_conv:\n            self.conv = make_conv2d(in_channels, in_channels, kernel_size=3, stride=1, 
causality_axis=causality_axis)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode=\"nearest\")\n        if self.with_conv:\n            x = self.conv(x)\n            # Drop FIRST element in the causal axis to undo encoder's padding, while keeping the length 1 + 2 * n.\n            # For example, if the input is [0, 1, 2], after interpolation, the output is [0, 0, 1, 1, 2, 2].\n            # The causal convolution will pad the first element as [-, -, 0, 0, 1, 1, 2, 2],\n            # So the output elements rely on the following windows:\n            # 0: [-,-,0]\n            # 1: [-,0,0]\n            # 2: [0,0,1]\n            # 3: [0,1,1]\n            # 4: [1,1,2]\n            # 5: [1,2,2]\n            # Notice that the first and second elements in the output rely only on the first element in the input,\n            # while all other elements rely on two elements in the input.\n            # So we can drop the first element to undo the padding (rather than the last element).\n            # This is a no-op for non-causal convolutions.\n            match self.causality_axis:\n                case CausalityAxis.NONE:\n                    pass  # x remains unchanged\n                case CausalityAxis.HEIGHT:\n                    x = x[:, :, 1:, :]\n                case CausalityAxis.WIDTH:\n                    x = x[:, :, :, 1:]\n                case CausalityAxis.WIDTH_COMPATIBILITY:\n                    pass  # x remains unchanged\n                case _:\n                    raise ValueError(f\"Invalid causality_axis: {self.causality_axis}\")\n\n        return x\n\n\ndef build_upsampling_path(  # noqa: PLR0913\n    *,\n    ch: int,\n    ch_mult: Tuple[int, ...],\n    num_resolutions: int,\n    num_res_blocks: int,\n    resolution: int,\n    temb_channels: int,\n    dropout: float,\n    norm_type: NormType,\n    causality_axis: CausalityAxis,\n    attn_type: AttentionType,\n    
attn_resolutions: Set[int],\n    resamp_with_conv: bool,\n    initial_block_channels: int,\n) -> tuple[torch.nn.ModuleList, int]:\n    \"\"\"Build the upsampling path with residual blocks, attention, and upsampling layers.\"\"\"\n    up_modules = torch.nn.ModuleList()\n    block_in = initial_block_channels\n    curr_res = resolution // (2 ** (num_resolutions - 1))\n\n    for level in reversed(range(num_resolutions)):\n        stage = torch.nn.Module()\n        stage.block = torch.nn.ModuleList()\n        stage.attn = torch.nn.ModuleList()\n        block_out = ch * ch_mult[level]\n\n        for _ in range(num_res_blocks + 1):\n            stage.block.append(\n                ResnetBlock(\n                    in_channels=block_in,\n                    out_channels=block_out,\n                    temb_channels=temb_channels,\n                    dropout=dropout,\n                    norm_type=norm_type,\n                    causality_axis=causality_axis,\n                )\n            )\n            block_in = block_out\n            if curr_res in attn_resolutions:\n                stage.attn.append(make_attn(block_in, attn_type=attn_type, norm_type=norm_type))\n\n        if level != 0:\n            stage.upsample = Upsample(block_in, resamp_with_conv, causality_axis=causality_axis)\n            curr_res *= 2\n\n        up_modules.insert(0, stage)\n\n    return up_modules, block_in\n\n\nclass PerChannelStatistics(nn.Module):\n    \"\"\"\n    Per-channel statistics for normalizing and denormalizing the latent representation.\n    This statics is computed over the entire dataset and stored in model's checkpoint under AudioVAE state_dict.\n    \"\"\"\n\n    def __init__(self, latent_channels: int = 128) -> None:\n        super().__init__()\n        self.register_buffer(\"std-of-means\", torch.empty(latent_channels))\n        self.register_buffer(\"mean-of-means\", torch.empty(latent_channels))\n\n    def un_normalize(self, x: torch.Tensor) -> torch.Tensor:\n        
return (x * self.get_buffer(\"std-of-means\").to(x)) + self.get_buffer(\"mean-of-means\").to(x)\n\n    def normalize(self, x: torch.Tensor) -> torch.Tensor:\n        return (x - self.get_buffer(\"mean-of-means\").to(x)) / self.get_buffer(\"std-of-means\").to(x)\n\n\nLATENT_DOWNSAMPLE_FACTOR = 4\n\n\ndef build_mid_block(\n    channels: int,\n    temb_channels: int,\n    dropout: float,\n    norm_type: NormType,\n    causality_axis: CausalityAxis,\n    attn_type: AttentionType,\n    add_attention: bool,\n) -> torch.nn.Module:\n    \"\"\"Build the middle block with two ResNet blocks and optional attention.\"\"\"\n    mid = torch.nn.Module()\n    mid.block_1 = ResnetBlock(\n        in_channels=channels,\n        out_channels=channels,\n        temb_channels=temb_channels,\n        dropout=dropout,\n        norm_type=norm_type,\n        causality_axis=causality_axis,\n    )\n    mid.attn_1 = make_attn(channels, attn_type=attn_type, norm_type=norm_type) if add_attention else torch.nn.Identity()\n    mid.block_2 = ResnetBlock(\n        in_channels=channels,\n        out_channels=channels,\n        temb_channels=temb_channels,\n        dropout=dropout,\n        norm_type=norm_type,\n        causality_axis=causality_axis,\n    )\n    return mid\n\n\ndef run_mid_block(mid: torch.nn.Module, features: torch.Tensor) -> torch.Tensor:\n    \"\"\"Run features through the middle block.\"\"\"\n    features = mid.block_1(features, temb=None)\n    features = mid.attn_1(features)\n    return mid.block_2(features, temb=None)\n\n\nclass LTX2AudioEncoder(torch.nn.Module):\n    \"\"\"\n    Encoder that compresses audio spectrograms into latent representations.\n    The encoder uses a series of downsampling blocks with residual connections,\n    attention mechanisms, and configurable causal convolutions.\n    \"\"\"\n\n    def __init__(  # noqa: PLR0913\n        self,\n        *,\n        ch: int = 128,\n        ch_mult: Tuple[int, ...] 
= (1, 2, 4),\n        num_res_blocks: int = 2,\n        attn_resolutions: Set[int] = set(),\n        dropout: float = 0.0,\n        resamp_with_conv: bool = True,\n        in_channels: int = 2,\n        resolution: int = 256,\n        z_channels: int = 8,\n        double_z: bool = True,\n        attn_type: AttentionType = AttentionType.VANILLA,\n        mid_block_add_attention: bool = False,\n        norm_type: NormType = NormType.PIXEL,\n        causality_axis: CausalityAxis = CausalityAxis.HEIGHT,\n        sample_rate: int = 16000,\n        mel_hop_length: int = 160,\n        n_fft: int = 1024,\n        is_causal: bool = True,\n        mel_bins: int = 64,\n        **_ignore_kwargs,\n    ) -> None:\n        \"\"\"\n        Initialize the Encoder.\n        Args:\n            Arguments are configuration parameters, loaded from the audio VAE checkpoint config\n            (audio_vae.model.params.ddconfig):\n            ch: Base number of feature channels used in the first convolution layer.\n            ch_mult: Multiplicative factors for the number of channels at each resolution level.\n            num_res_blocks: Number of residual blocks to use at each resolution level.\n            attn_resolutions: Spatial resolutions (e.g., in time/frequency) at which to apply attention.\n            resolution: Input spatial resolution of the spectrogram (height, width).\n            z_channels: Number of channels in the latent representation.\n            norm_type: Normalization layer type to use within the network (e.g., group, batch).\n            causality_axis: Axis along which convolutions should be causal (e.g., time axis).\n            sample_rate: Audio sample rate in Hz for the input signals.\n            mel_hop_length: Hop length used when computing the mel spectrogram.\n            n_fft: FFT size used to compute the spectrogram.\n            mel_bins: Number of mel-frequency bins in the input spectrogram.\n            in_channels: Number of channels in the input 
spectrogram tensor.\n            double_z: If True, predict both mean and log-variance (doubling latent channels).\n            is_causal: If True, use causal convolutions suitable for streaming setups.\n            dropout: Dropout probability used in residual and mid blocks.\n            attn_type: Type of attention mechanism to use in attention blocks.\n            resamp_with_conv: If True, perform resolution changes using strided convolutions.\n            mid_block_add_attention: If True, add an attention block in the mid-level of the encoder.\n        \"\"\"\n        super().__init__()\n\n        self.per_channel_statistics = PerChannelStatistics(latent_channels=ch)\n        self.sample_rate = sample_rate\n        self.mel_hop_length = mel_hop_length\n        self.n_fft = n_fft\n        self.is_causal = is_causal\n        self.mel_bins = mel_bins\n\n        self.patchifier = AudioPatchifier(\n            patch_size=1,\n            audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,\n            sample_rate=sample_rate,\n            hop_length=mel_hop_length,\n            is_causal=is_causal,\n        )\n\n        self.ch = ch\n        self.temb_ch = 0\n        self.num_resolutions = len(ch_mult)\n        self.num_res_blocks = num_res_blocks\n        self.resolution = resolution\n        self.in_channels = in_channels\n        self.z_channels = z_channels\n        self.double_z = double_z\n        self.norm_type = norm_type\n        self.causality_axis = causality_axis\n        self.attn_type = attn_type\n\n        # downsampling\n        self.conv_in = make_conv2d(\n            in_channels,\n            self.ch,\n            kernel_size=3,\n            stride=1,\n            causality_axis=self.causality_axis,\n        )\n\n        self.non_linearity = torch.nn.SiLU()\n\n        self.down, block_in = build_downsampling_path(\n            ch=ch,\n            ch_mult=ch_mult,\n            num_resolutions=self.num_resolutions,\n            
num_res_blocks=num_res_blocks,\n            resolution=resolution,\n            temb_channels=self.temb_ch,\n            dropout=dropout,\n            norm_type=self.norm_type,\n            causality_axis=self.causality_axis,\n            attn_type=self.attn_type,\n            attn_resolutions=attn_resolutions,\n            resamp_with_conv=resamp_with_conv,\n        )\n\n        self.mid = build_mid_block(\n            channels=block_in,\n            temb_channels=self.temb_ch,\n            dropout=dropout,\n            norm_type=self.norm_type,\n            causality_axis=self.causality_axis,\n            attn_type=self.attn_type,\n            add_attention=mid_block_add_attention,\n        )\n\n        self.norm_out = build_normalization_layer(block_in, normtype=self.norm_type)\n        self.conv_out = make_conv2d(\n            block_in,\n            2 * z_channels if double_z else z_channels,\n            kernel_size=3,\n            stride=1,\n            causality_axis=self.causality_axis,\n        )\n\n    def forward(self, spectrogram: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Encode audio spectrogram into latent representations.\n        Args:\n            spectrogram: Input spectrogram of shape (batch, channels, time, frequency)\n        Returns:\n            Encoded latent representation of shape (batch, channels, frames, mel_bins)\n        \"\"\"\n        h = self.conv_in(spectrogram)\n        h = self._run_downsampling_path(h)\n        h = run_mid_block(self.mid, h)\n        h = self._finalize_output(h)\n\n        return self._normalize_latents(h)\n\n    def _run_downsampling_path(self, h: torch.Tensor) -> torch.Tensor:\n        for level in range(self.num_resolutions):\n            stage = self.down[level]\n            for block_idx in range(self.num_res_blocks):\n                h = stage.block[block_idx](h, temb=None)\n                if stage.attn:\n                    h = stage.attn[block_idx](h)\n\n            if level != 
self.num_resolutions - 1:\n                h = stage.downsample(h)\n\n        return h\n\n    def _finalize_output(self, h: torch.Tensor) -> torch.Tensor:\n        h = self.norm_out(h)\n        h = self.non_linearity(h)\n        return self.conv_out(h)\n\n    def _normalize_latents(self, latent_output: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Normalize encoder latents using per-channel statistics.\n        When the encoder is configured with ``double_z=True``, the final\n        convolution produces twice the number of latent channels, typically\n        interpreted as two concatenated tensors along the channel dimension\n        (e.g., mean and variance or other auxiliary parameters).\n        This method intentionally uses only the first half of the channels\n        (the \"mean\" component) as input to the patchifier and normalization\n        logic. The remaining channels are left unchanged by this method and\n        are expected to be consumed elsewhere in the VAE pipeline.\n        If ``double_z=False``, the encoder output already contains only the\n        mean latents and the chunking operation simply returns that tensor.\n        \"\"\"\n        means = torch.chunk(latent_output, 2, dim=1)[0]\n        latent_shape = AudioLatentShape(\n            batch=means.shape[0],\n            channels=means.shape[1],\n            frames=means.shape[2],\n            mel_bins=means.shape[3],\n        )\n        latent_patched = self.patchifier.patchify(means)\n        latent_normalized = self.per_channel_statistics.normalize(latent_patched)\n        return self.patchifier.unpatchify(latent_normalized, latent_shape)\n\n\nclass LTX2AudioDecoder(torch.nn.Module):\n    \"\"\"\n    Symmetric decoder that reconstructs audio spectrograms from latent features.\n    The decoder mirrors the encoder structure with configurable channel multipliers,\n    attention resolutions, and causal convolutions.\n    \"\"\"\n\n    def __init__(  # noqa: PLR0913\n        self,\n 
       *,\n        ch: int = 128,\n        out_ch: int = 2,\n        ch_mult: Tuple[int, ...] = (1, 2, 4),\n        num_res_blocks: int = 2,\n        attn_resolutions: Set[int] = set(),\n        resolution: int=256,\n        z_channels: int=8,\n        norm_type: NormType = NormType.PIXEL,\n        causality_axis: CausalityAxis = CausalityAxis.HEIGHT,\n        dropout: float = 0.0,\n        mid_block_add_attention: bool = False,\n        sample_rate: int = 16000,\n        mel_hop_length: int = 160,\n        is_causal: bool = True,\n        mel_bins: int | None = 64,\n    ) -> None:\n        \"\"\"\n        Initialize the Decoder.\n        Args:\n            Arguments are configuration parameters, loaded from the audio VAE checkpoint config\n            (audio_vae.model.params.ddconfig):\n            - ch, out_ch, ch_mult, num_res_blocks, attn_resolutions\n            - resolution, z_channels\n            - norm_type, causality_axis\n        \"\"\"\n        super().__init__()\n\n        # Internal behavioural defaults that are not driven by the checkpoint.\n        resamp_with_conv = True\n        attn_type = AttentionType.VANILLA\n\n        # Per-channel statistics for denormalizing latents\n        self.per_channel_statistics = PerChannelStatistics(latent_channels=ch)\n        self.sample_rate = sample_rate\n        self.mel_hop_length = mel_hop_length\n        self.is_causal = is_causal\n        self.mel_bins = mel_bins\n        self.patchifier = AudioPatchifier(\n            patch_size=1,\n            audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,\n            sample_rate=sample_rate,\n            hop_length=mel_hop_length,\n            is_causal=is_causal,\n        )\n\n        self.ch = ch\n        self.temb_ch = 0\n        self.num_resolutions = len(ch_mult)\n        self.num_res_blocks = num_res_blocks\n        self.resolution = resolution\n        self.out_ch = out_ch\n        self.give_pre_end = False\n        self.tanh_out = False\n        
self.norm_type = norm_type\n        self.z_channels = z_channels\n        self.channel_multipliers = ch_mult\n        self.attn_resolutions = attn_resolutions\n        self.causality_axis = causality_axis\n        self.attn_type = attn_type\n\n        base_block_channels = ch * self.channel_multipliers[-1]\n        base_resolution = resolution // (2 ** (self.num_resolutions - 1))\n        self.z_shape = (1, z_channels, base_resolution, base_resolution)\n\n        self.conv_in = make_conv2d(\n            z_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis\n        )\n        self.non_linearity = torch.nn.SiLU()\n        self.mid = build_mid_block(\n            channels=base_block_channels,\n            temb_channels=self.temb_ch,\n            dropout=dropout,\n            norm_type=self.norm_type,\n            causality_axis=self.causality_axis,\n            attn_type=self.attn_type,\n            add_attention=mid_block_add_attention,\n        )\n        self.up, final_block_channels = build_upsampling_path(\n            ch=ch,\n            ch_mult=ch_mult,\n            num_resolutions=self.num_resolutions,\n            num_res_blocks=num_res_blocks,\n            resolution=resolution,\n            temb_channels=self.temb_ch,\n            dropout=dropout,\n            norm_type=self.norm_type,\n            causality_axis=self.causality_axis,\n            attn_type=self.attn_type,\n            attn_resolutions=attn_resolutions,\n            resamp_with_conv=resamp_with_conv,\n            initial_block_channels=base_block_channels,\n        )\n\n        self.norm_out = build_normalization_layer(final_block_channels, normtype=self.norm_type)\n        self.conv_out = make_conv2d(\n            final_block_channels, out_ch, kernel_size=3, stride=1, causality_axis=self.causality_axis\n        )\n\n    def forward(self, sample: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Decode latent features back to audio 
spectrograms.\n        Args:\n            sample: Encoded latent representation of shape (batch, channels, frames, mel_bins)\n        Returns:\n            Reconstructed audio spectrogram of shape (batch, channels, time, frequency)\n        \"\"\"\n        sample, target_shape = self._denormalize_latents(sample)\n\n        h = self.conv_in(sample)\n        h = run_mid_block(self.mid, h)\n        h = self._run_upsampling_path(h)\n        h = self._finalize_output(h)\n\n        return self._adjust_output_shape(h, target_shape)\n\n    def _denormalize_latents(self, sample: torch.Tensor) -> tuple[torch.Tensor, AudioLatentShape]:\n        latent_shape = AudioLatentShape(\n            batch=sample.shape[0],\n            channels=sample.shape[1],\n            frames=sample.shape[2],\n            mel_bins=sample.shape[3],\n        )\n\n        sample_patched = self.patchifier.patchify(sample)\n        sample_denormalized = self.per_channel_statistics.un_normalize(sample_patched)\n        sample = self.patchifier.unpatchify(sample_denormalized, latent_shape)\n\n        target_frames = latent_shape.frames * LATENT_DOWNSAMPLE_FACTOR\n        if self.causality_axis != CausalityAxis.NONE:\n            target_frames = max(target_frames - (LATENT_DOWNSAMPLE_FACTOR - 1), 1)\n\n        target_shape = AudioLatentShape(\n            batch=latent_shape.batch,\n            channels=self.out_ch,\n            frames=target_frames,\n            mel_bins=self.mel_bins if self.mel_bins is not None else latent_shape.mel_bins,\n        )\n\n        return sample, target_shape\n\n    def _adjust_output_shape(\n        self,\n        decoded_output: torch.Tensor,\n        target_shape: AudioLatentShape,\n    ) -> torch.Tensor:\n        \"\"\"\n        Adjust output shape to match target dimensions for variable-length audio.\n        This function handles the common case where decoded audio spectrograms need to be\n        resized to match a specific target shape.\n        Args:\n            
decoded_output: Tensor of shape (batch, channels, time, frequency)\n            target_shape: AudioLatentShape describing (batch, channels, time, mel bins)\n        Returns:\n            Tensor adjusted to match target_shape exactly\n        \"\"\"\n        # Current output shape: (batch, channels, time, frequency)\n        _, _, current_time, current_freq = decoded_output.shape\n        target_channels = target_shape.channels\n        target_time = target_shape.frames\n        target_freq = target_shape.mel_bins\n\n        # Step 1: Crop first to avoid exceeding target dimensions\n        decoded_output = decoded_output[\n            :, :target_channels, : min(current_time, target_time), : min(current_freq, target_freq)\n        ]\n\n        # Step 2: Calculate padding needed for time and frequency dimensions\n        time_padding_needed = target_time - decoded_output.shape[2]\n        freq_padding_needed = target_freq - decoded_output.shape[3]\n\n        # Step 3: Apply padding if needed\n        if time_padding_needed > 0 or freq_padding_needed > 0:\n            # PyTorch padding format: (pad_left, pad_right, pad_top, pad_bottom)\n            # For audio: pad_left/right = frequency, pad_top/bottom = time\n            padding = (\n                0,\n                max(freq_padding_needed, 0),  # frequency padding (left, right)\n                0,\n                max(time_padding_needed, 0),  # time padding (top, bottom)\n            )\n            decoded_output = F.pad(decoded_output, padding)\n\n        # Step 4: Final safety crop to ensure exact target shape\n        decoded_output = decoded_output[:, :target_channels, :target_time, :target_freq]\n\n        return decoded_output\n\n    def _run_upsampling_path(self, h: torch.Tensor) -> torch.Tensor:\n        for level in reversed(range(self.num_resolutions)):\n            stage = self.up[level]\n            for block_idx, block in enumerate(stage.block):\n                h = block(h, temb=None)\n            
    if stage.attn:\n                    h = stage.attn[block_idx](h)\n\n            if level != 0 and hasattr(stage, \"upsample\"):\n                h = stage.upsample(h)\n\n        return h\n\n    def _finalize_output(self, h: torch.Tensor) -> torch.Tensor:\n        if self.give_pre_end:\n            return h\n\n        h = self.norm_out(h)\n        h = self.non_linearity(h)\n        h = self.conv_out(h)\n        return torch.tanh(h) if self.tanh_out else h\n\n\ndef get_padding(kernel_size: int, dilation: int = 1) -> int:\n    return int((kernel_size * dilation - dilation) / 2)\n\n\n# ---------------------------------------------------------------------------\n# Anti-aliased resampling helpers (kaiser-sinc filters) for BigVGAN v2\n# Adopted from https://github.com/NVIDIA/BigVGAN\n# ---------------------------------------------------------------------------\n\n\ndef _sinc(x: torch.Tensor) -> torch.Tensor:\n    return torch.where(\n        x == 0,\n        torch.tensor(1.0, device=x.device, dtype=x.dtype),\n        torch.sin(math.pi * x) / math.pi / x,\n    )\n\n\ndef kaiser_sinc_filter1d(cutoff: float, half_width: float, kernel_size: int) -> torch.Tensor:\n    even = kernel_size % 2 == 0\n    half_size = kernel_size // 2\n    delta_f = 4 * half_width\n    amplitude = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95\n    if amplitude > 50.0:\n        beta = 0.1102 * (amplitude - 8.7)\n    elif amplitude >= 21.0:\n        beta = 0.5842 * (amplitude - 21) ** 0.4 + 0.07886 * (amplitude - 21.0)\n    else:\n        beta = 0.0\n    window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)\n    time = torch.arange(-half_size, half_size) + 0.5 if even else torch.arange(kernel_size) - half_size\n    if cutoff == 0:\n        filter_ = torch.zeros_like(time)\n    else:\n        filter_ = 2 * cutoff * window * _sinc(2 * cutoff * time)\n        filter_ /= filter_.sum()\n    return filter_.view(1, 1, kernel_size)\n\n\nclass LowPassFilter1d(nn.Module):\n    def 
__init__(\n        self,\n        cutoff: float = 0.5,\n        half_width: float = 0.6,\n        stride: int = 1,\n        padding: bool = True,\n        padding_mode: str = \"replicate\",\n        kernel_size: int = 12,\n    ) -> None:\n        super().__init__()\n        if cutoff < -0.0:\n            raise ValueError(\"Minimum cutoff must be larger than zero.\")\n        if cutoff > 0.5:\n            raise ValueError(\"A cutoff above 0.5 does not make sense.\")\n        self.kernel_size = kernel_size\n        self.even = kernel_size % 2 == 0\n        self.pad_left = kernel_size // 2 - int(self.even)\n        self.pad_right = kernel_size // 2\n        self.stride = stride\n        self.padding = padding\n        self.padding_mode = padding_mode\n        self.register_buffer(\"filter\", kaiser_sinc_filter1d(cutoff, half_width, kernel_size))\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        _, n_channels, _ = x.shape\n        if self.padding:\n            x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)\n        return F.conv1d(x, self.filter.expand(n_channels, -1, -1), stride=self.stride, groups=n_channels)\n\n\nclass UpSample1d(nn.Module):\n    def __init__(\n        self,\n        ratio: int = 2,\n        kernel_size: int | None = None,\n        persistent: bool = True,\n        window_type: str = \"kaiser\",\n    ) -> None:\n        super().__init__()\n        self.ratio = ratio\n        self.stride = ratio\n\n        if window_type == \"hann\":\n            # Hann-windowed sinc filter equivalent to torchaudio.functional.resample\n            rolloff = 0.99\n            lowpass_filter_width = 6\n            width = math.ceil(lowpass_filter_width / rolloff)\n            self.kernel_size = 2 * width * ratio + 1\n            self.pad = width\n            self.pad_left = 2 * width * ratio\n            self.pad_right = self.kernel_size - ratio\n            time_axis = (torch.arange(self.kernel_size) / ratio - width) * 
rolloff\n            time_clamped = time_axis.clamp(-lowpass_filter_width, lowpass_filter_width)\n            window = torch.cos(time_clamped * math.pi / lowpass_filter_width / 2) ** 2\n            sinc_filter = (torch.sinc(time_axis) * window * rolloff / ratio).view(1, 1, -1)\n        else:\n            # Kaiser-windowed sinc filter (BigVGAN default).\n            self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size\n            self.pad = self.kernel_size // ratio - 1\n            self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2\n            self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2\n            sinc_filter = kaiser_sinc_filter1d(\n                cutoff=0.5 / ratio,\n                half_width=0.6 / ratio,\n                kernel_size=self.kernel_size,\n            )\n\n        self.register_buffer(\"filter\", sinc_filter, persistent=persistent)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        _, n_channels, _ = x.shape\n        x = F.pad(x, (self.pad, self.pad), mode=\"replicate\")\n        filt = self.filter.to(dtype=x.dtype, device=x.device).expand(n_channels, -1, -1)\n        x = self.ratio * F.conv_transpose1d(x, filt, stride=self.stride, groups=n_channels)\n        return x[..., self.pad_left : -self.pad_right]\n\n\nclass DownSample1d(nn.Module):\n    def __init__(self, ratio: int = 2, kernel_size: int | None = None) -> None:\n        super().__init__()\n        self.ratio = ratio\n        self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size\n        self.lowpass = LowPassFilter1d(\n            cutoff=0.5 / ratio,\n            half_width=0.6 / ratio,\n            stride=ratio,\n            kernel_size=self.kernel_size,\n        )\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        return self.lowpass(x)\n\n\nclass Activation1d(nn.Module):\n    def __init__(\n        self,\n        activation: 
nn.Module,\n        up_ratio: int = 2,\n        down_ratio: int = 2,\n        up_kernel_size: int = 12,\n        down_kernel_size: int = 12,\n    ) -> None:\n        super().__init__()\n        self.act = activation\n        self.upsample = UpSample1d(up_ratio, up_kernel_size)\n        self.downsample = DownSample1d(down_ratio, down_kernel_size)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x = self.upsample(x)\n        x = self.act(x)\n        return self.downsample(x)\n\n\nclass Snake(nn.Module):\n    def __init__(\n        self,\n        in_features: int,\n        alpha: float = 1.0,\n        alpha_trainable: bool = True,\n        alpha_logscale: bool = True,\n    ) -> None:\n        super().__init__()\n        self.alpha_logscale = alpha_logscale\n        self.alpha = nn.Parameter(torch.zeros(in_features) if alpha_logscale else torch.ones(in_features) * alpha)\n        self.alpha.requires_grad = alpha_trainable\n        self.eps = 1e-9\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)\n        if self.alpha_logscale:\n            alpha = torch.exp(alpha)\n        return x + (1.0 / (alpha + self.eps)) * torch.sin(x * alpha).pow(2)\n\n\nclass SnakeBeta(nn.Module):\n    def __init__(\n        self,\n        in_features: int,\n        alpha: float = 1.0,\n        alpha_trainable: bool = True,\n        alpha_logscale: bool = True,\n    ) -> None:\n        super().__init__()\n        self.alpha_logscale = alpha_logscale\n        self.alpha = nn.Parameter(torch.zeros(in_features) if alpha_logscale else torch.ones(in_features) * alpha)\n        self.alpha.requires_grad = alpha_trainable\n        self.beta = nn.Parameter(torch.zeros(in_features) if alpha_logscale else torch.ones(in_features) * alpha)\n        self.beta.requires_grad = alpha_trainable\n        self.eps = 1e-9\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        alpha = 
self.alpha.unsqueeze(0).unsqueeze(-1)\n        beta = self.beta.unsqueeze(0).unsqueeze(-1)\n        if self.alpha_logscale:\n            alpha = torch.exp(alpha)\n            beta = torch.exp(beta)\n        return x + (1.0 / (beta + self.eps)) * torch.sin(x * alpha).pow(2)\n\n\nclass AMPBlock1(nn.Module):\n    def __init__(\n        self,\n        channels: int,\n        kernel_size: int = 3,\n        dilation: tuple[int, int, int] = (1, 3, 5),\n        activation: str = \"snake\",\n    ) -> None:\n        super().__init__()\n        act_cls = SnakeBeta if activation == \"snakebeta\" else Snake\n        self.convs1 = nn.ModuleList(\n            [\n                nn.Conv1d(\n                    channels,\n                    channels,\n                    kernel_size,\n                    1,\n                    dilation=dilation[0],\n                    padding=get_padding(kernel_size, dilation[0]),\n                ),\n                nn.Conv1d(\n                    channels,\n                    channels,\n                    kernel_size,\n                    1,\n                    dilation=dilation[1],\n                    padding=get_padding(kernel_size, dilation[1]),\n                ),\n                nn.Conv1d(\n                    channels,\n                    channels,\n                    kernel_size,\n                    1,\n                    dilation=dilation[2],\n                    padding=get_padding(kernel_size, dilation[2]),\n                ),\n            ]\n        )\n\n        self.convs2 = nn.ModuleList(\n            [\n                nn.Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)),\n                nn.Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)),\n                nn.Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)),\n            ]\n        )\n\n        self.acts1 = 
nn.ModuleList([Activation1d(act_cls(channels)) for _ in range(len(self.convs1))])\n        self.acts2 = nn.ModuleList([Activation1d(act_cls(channels)) for _ in range(len(self.convs2))])\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        for c1, c2, a1, a2 in zip(self.convs1, self.convs2, self.acts1, self.acts2, strict=True):\n            xt = a1(x)\n            xt = c1(xt)\n            xt = a2(xt)\n            xt = c2(xt)\n            x = x + xt\n        return x\n\n\nclass LTX2Vocoder(torch.nn.Module):\n    \"\"\"\n    LTX2Vocoder model for synthesizing audio from Mel spectrograms.\n    Args:\n        resblock_kernel_sizes: List of kernel sizes for the residual blocks.\n                               This value is read from the checkpoint at `config.vocoder.resblock_kernel_sizes`.\n        upsample_rates: List of upsampling rates.\n                               This value is read from the checkpoint at `config.vocoder.upsample_rates`.\n        upsample_kernel_sizes: List of kernel sizes for the upsampling layers.\n                               This value is read from the checkpoint at `config.vocoder.upsample_kernel_sizes`.\n        resblock_dilation_sizes: List of dilation sizes for the residual blocks.\n                               This value is read from the checkpoint at `config.vocoder.resblock_dilation_sizes`.\n        upsample_initial_channel: Initial number of channels for the upsampling layers.\n                               This value is read from the checkpoint at `config.vocoder.upsample_initial_channel`.\n        resblock: Type of residual block to use (\"1\", \"2\", or \"AMP1\").\n                                This value is read from the checkpoint at `config.vocoder.resblock`.\n        output_sampling_rate: Waveform sample rate.\n                               This value is read from the checkpoint at `config.vocoder.output_sampling_rate`.\n        activation: Activation type for BigVGAN v2 (\"snake\" or \"snakebeta\"). 
Only used when resblock=\"AMP1\".\n        use_tanh_at_final: Apply tanh at the output (when apply_final_activation=True).\n        apply_final_activation: Whether to apply the final tanh/clamp activation.\n        use_bias_at_final: Whether to use bias in the final conv layer.\n    \"\"\"\n\n    def __init__(  # noqa: PLR0913\n        self,\n        resblock_kernel_sizes: List[int] | None = [3, 7, 11],\n        upsample_rates: List[int] | None = [6, 5, 2, 2, 2],\n        upsample_kernel_sizes: List[int] | None = [16, 15, 8, 4, 4],\n        resblock_dilation_sizes: List[List[int]] | None = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],\n        upsample_initial_channel: int = 1024,\n        resblock: str = \"1\",\n        output_sampling_rate: int = 24000,\n        activation: str = \"snake\",\n        use_tanh_at_final: bool = True,\n        apply_final_activation: bool = True,\n        use_bias_at_final: bool = True,\n    ) -> None:\n        super().__init__()\n\n        # Mutable default values are not supported as default arguments.\n        if resblock_kernel_sizes is None:\n            resblock_kernel_sizes = [3, 7, 11]\n        if upsample_rates is None:\n            upsample_rates = [6, 5, 2, 2, 2]\n        if upsample_kernel_sizes is None:\n            upsample_kernel_sizes = [16, 15, 8, 4, 4]\n        if resblock_dilation_sizes is None:\n            resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]]\n\n        self.output_sampling_rate = output_sampling_rate\n        self.num_kernels = len(resblock_kernel_sizes)\n        self.num_upsamples = len(upsample_rates)\n        self.use_tanh_at_final = use_tanh_at_final\n        self.apply_final_activation = apply_final_activation\n        self.is_amp = resblock == \"AMP1\"\n\n        # All production checkpoints are stereo: 128 input channels (2 stereo channels x 64 mel\n        # bins each), 2 output channels.\n        self.conv_pre = nn.Conv1d(\n            in_channels=128,\n            
out_channels=upsample_initial_channel,\n            kernel_size=7,\n            stride=1,\n            padding=3,\n        )\n        resblock_cls = ResBlock1 if resblock == \"1\" else AMPBlock1\n\n        self.ups = nn.ModuleList(\n            nn.ConvTranspose1d(\n                upsample_initial_channel // (2**i),\n                upsample_initial_channel // (2 ** (i + 1)),\n                kernel_size,\n                stride,\n                padding=(kernel_size - stride) // 2,\n            )\n            for i, (stride, kernel_size) in enumerate(zip(upsample_rates, upsample_kernel_sizes, strict=True))\n        )\n\n        final_channels = upsample_initial_channel // (2 ** len(upsample_rates))\n        self.resblocks = nn.ModuleList()\n\n        for i in range(len(upsample_rates)):\n            ch = upsample_initial_channel // (2 ** (i + 1))\n            for kernel_size, dilations in zip(resblock_kernel_sizes, resblock_dilation_sizes, strict=True):\n                if self.is_amp:\n                    self.resblocks.append(resblock_cls(ch, kernel_size, dilations, activation=activation))\n                else:\n                    self.resblocks.append(resblock_cls(ch, kernel_size, dilations))\n\n        if self.is_amp:\n            self.act_post: nn.Module = Activation1d(SnakeBeta(final_channels))\n        else:\n            self.act_post = nn.LeakyReLU()\n\n        # All production checkpoints are stereo: this final conv maps `final_channels` to 2 output channels (stereo).\n        self.conv_post = nn.Conv1d(\n            in_channels=final_channels,\n            out_channels=2,\n            kernel_size=7,\n            stride=1,\n            padding=3,\n            bias=use_bias_at_final,\n        )\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Forward pass of the vocoder.\n        Args:\n            x: Input Mel spectrogram tensor. 
Can be either:\n               - 3D: (batch_size, time, mel_bins) for mono\n               - 4D: (batch_size, 2, time, mel_bins) for stereo\n        Returns:\n            Audio waveform tensor of shape (batch_size, out_channels, audio_length)\n        \"\"\"\n        x = x.transpose(2, 3)  # (batch, channels, time, mel_bins) -> (batch, channels, mel_bins, time)\n\n        if x.dim() == 4:  # stereo\n            assert x.shape[1] == 2, \"Input must have 2 channels for stereo\"\n            x = einops.rearrange(x, \"b s c t -> b (s c) t\")\n\n        x = self.conv_pre(x)\n\n        for i in range(self.num_upsamples):\n            if not self.is_amp:\n                x = F.leaky_relu(x, LRELU_SLOPE)\n            x = self.ups[i](x)\n            start = i * self.num_kernels\n            end = start + self.num_kernels\n\n            # Evaluate all resblocks with the same input tensor so they can run\n            # independently (and thus in parallel on accelerator hardware) before\n            # aggregating their outputs via mean.\n            block_outputs = torch.stack(\n                [self.resblocks[idx](x) for idx in range(start, end)],\n                dim=0,\n            )\n            x = block_outputs.mean(dim=0)\n\n        x = self.act_post(x)\n        x = self.conv_post(x)\n\n        if self.apply_final_activation:\n            x = torch.tanh(x) if self.use_tanh_at_final else torch.clamp(x, -1, 1)\n\n        return x\n\n\nclass _STFTFn(nn.Module):\n    \"\"\"Implements STFT as a convolution with precomputed DFT x Hann-window bases.\n    The DFT basis rows (real and imaginary parts interleaved) multiplied by the causal\n    Hann window are stored as buffers and loaded from the checkpoint. 
Using the exact\n    bfloat16 bases from training ensures the mel values fed to the BWE generator are\n    bit-identical to what it was trained on.\n    \"\"\"\n\n    def __init__(self, filter_length: int, hop_length: int, win_length: int) -> None:\n        super().__init__()\n        self.hop_length = hop_length\n        self.win_length = win_length\n        n_freqs = filter_length // 2 + 1\n        self.register_buffer(\"forward_basis\", torch.zeros(n_freqs * 2, 1, filter_length))\n        self.register_buffer(\"inverse_basis\", torch.zeros(n_freqs * 2, 1, filter_length))\n\n    def forward(self, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"Compute magnitude and phase spectrogram from a batch of waveforms.\n        Applies causal (left-only) padding of win_length - hop_length samples so that\n        each output frame depends only on past and present input — no lookahead.\n        Args:\n            y: Waveform tensor of shape (B, T).\n        Returns:\n            magnitude: Linear amplitude spectrogram, shape (B, n_freqs, T_frames).\n            phase:     Phase spectrogram in radians, shape (B, n_freqs, T_frames).\n        \"\"\"\n        if y.dim() == 2:\n            y = y.unsqueeze(1)  # (B, 1, T)\n        left_pad = max(0, self.win_length - self.hop_length)  # causal: left-only\n        y = F.pad(y, (left_pad, 0))\n        spec = F.conv1d(y, self.forward_basis, stride=self.hop_length, padding=0)\n        n_freqs = spec.shape[1] // 2\n        real, imag = spec[:, :n_freqs], spec[:, n_freqs:]\n        magnitude = torch.sqrt(real**2 + imag**2)\n        phase = torch.atan2(imag.float(), real.float()).to(real.dtype)\n        return magnitude, phase\n\n\nclass MelSTFT(nn.Module):\n    \"\"\"Causal log-mel spectrogram module whose buffers are loaded from the checkpoint.\n    Computes a log-mel spectrogram by running the causal STFT (_STFTFn) on the input\n    waveform and projecting the linear magnitude spectrum onto the mel filterbank.\n  
  The module's state dict layout matches the 'mel_stft.*' keys stored in the checkpoint\n    (mel_basis, stft_fn.forward_basis, stft_fn.inverse_basis).\n    \"\"\"\n\n    def __init__(\n        self,\n        filter_length: int,\n        hop_length: int,\n        win_length: int,\n        n_mel_channels: int,\n    ) -> None:\n        super().__init__()\n        self.stft_fn = _STFTFn(filter_length, hop_length, win_length)\n\n        # Initialized to zeros; load_state_dict overwrites with the checkpoint's\n        # exact bfloat16 filterbank (vocoder.mel_stft.mel_basis, shape [n_mels, n_freqs]).\n        n_freqs = filter_length // 2 + 1\n        self.register_buffer(\"mel_basis\", torch.zeros(n_mel_channels, n_freqs))\n\n    def mel_spectrogram(self, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n        \"\"\"Compute log-mel spectrogram and auxiliary spectral quantities.\n        Args:\n            y: Waveform tensor of shape (B, T).\n        Returns:\n            log_mel:   Log-compressed mel spectrogram, shape (B, n_mel_channels, T_frames).\n            magnitude: Linear amplitude spectrogram, shape (B, n_freqs, T_frames).\n            phase:     Phase spectrogram in radians, shape (B, n_freqs, T_frames).\n            energy:    Per-frame energy (L2 norm over frequency), shape (B, T_frames).\n        \"\"\"\n        magnitude, phase = self.stft_fn(y)\n        energy = torch.norm(magnitude, dim=1)\n        mel = torch.matmul(self.mel_basis.to(magnitude.dtype), magnitude)\n        log_mel = torch.log(torch.clamp(mel, min=1e-5))\n        return log_mel, magnitude, phase, energy\n\n\nclass LTX2VocoderWithBWE(nn.Module):\n    \"\"\"LTX2Vocoder with bandwidth extension (BWE) upsampling.\n    Chains a mel-to-wav vocoder with a BWE module that upsamples the output\n    to a higher sample rate. 
The BWE computes a mel spectrogram from the\n    vocoder output, runs it through a second generator to predict a residual,\n    and adds it to a sinc-resampled skip connection.\n    \"\"\"\n\n    def __init__(\n        self,\n        input_sampling_rate: int = 16000,\n        output_sampling_rate: int = 48000,\n        hop_length: int = 80,\n    ) -> None:\n        super().__init__()\n        self.vocoder = LTX2Vocoder(\n            resblock_kernel_sizes=[3, 7, 11],\n            upsample_rates=[5, 2, 2, 2, 2, 2],\n            upsample_kernel_sizes=[11, 4, 4, 4, 4, 4],\n            resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],\n            upsample_initial_channel=1536,\n            resblock=\"AMP1\",\n            activation=\"snakebeta\",\n            use_tanh_at_final=False,\n            apply_final_activation=True,\n            use_bias_at_final=False,\n            output_sampling_rate=input_sampling_rate,\n        )\n        self.bwe_generator = LTX2Vocoder(\n            resblock_kernel_sizes=[3, 7, 11],\n            upsample_rates=[6, 5, 2, 2, 2],\n            upsample_kernel_sizes=[12, 11, 4, 4, 4],\n            resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],\n            upsample_initial_channel=512,\n            resblock=\"AMP1\",\n            activation=\"snakebeta\",\n            use_tanh_at_final=False,\n            apply_final_activation=False,\n            use_bias_at_final=False,\n            output_sampling_rate=output_sampling_rate,\n        )\n        \n        self.mel_stft = MelSTFT(\n            filter_length=512,\n            hop_length=hop_length,\n            win_length=512,\n            n_mel_channels=64,\n        )\n        self.input_sampling_rate = input_sampling_rate\n        self.output_sampling_rate = output_sampling_rate\n        self.hop_length = hop_length\n        # Compute the resampler on CPU so the sinc filter is materialized even when\n        # the model is constructed on meta device 
(SingleGPUModelBuilder pattern).\n        # The filter is not stored in the checkpoint (persistent=False).\n        with torch.device(\"cpu\"):\n            self.resampler = UpSample1d(\n                ratio=output_sampling_rate // input_sampling_rate, persistent=False, window_type=\"hann\"\n            )\n\n    @property\n    def conv_pre(self) -> nn.Conv1d:\n        return self.vocoder.conv_pre\n\n    @property\n    def conv_post(self) -> nn.Conv1d:\n        return self.vocoder.conv_post\n\n    def _compute_mel(self, audio: torch.Tensor) -> torch.Tensor:\n        \"\"\"Compute log-mel spectrogram from waveform using causal STFT bases.\n        Args:\n            audio: Waveform tensor of shape (B, C, T).\n        Returns:\n            mel: Log-mel spectrogram of shape (B, C, n_mels, T_frames).\n        \"\"\"\n        batch, n_channels, _ = audio.shape\n        flat = audio.reshape(batch * n_channels, -1)  # (B*C, T)\n        mel, _, _, _ = self.mel_stft.mel_spectrogram(flat)  # (B*C, n_mels, T_frames)\n        return mel.reshape(batch, n_channels, mel.shape[1], mel.shape[2])  # (B, C, n_mels, T_frames)\n\n    def forward(self, mel_spec: torch.Tensor) -> torch.Tensor:\n        \"\"\"Run the full vocoder + BWE forward pass.\n        Args:\n            mel_spec: Mel spectrogram of shape (B, 2, T, mel_bins) for stereo\n                      or (B, T, mel_bins) for mono. 
Same format as LTX2Vocoder.forward.\n        Returns:\n            Waveform tensor of shape (B, out_channels, T_out) clipped to [-1, 1].\n        \"\"\"\n        x = self.vocoder(mel_spec)\n        _, _, length_low_rate = x.shape\n        output_length = length_low_rate * self.output_sampling_rate // self.input_sampling_rate\n\n        # Pad to multiple of hop_length for exact mel frame count\n        remainder = length_low_rate % self.hop_length\n        if remainder != 0:\n            x = F.pad(x, (0, self.hop_length - remainder))\n\n        # Compute mel spectrogram from vocoder output: (B, C, n_mels, T_frames)\n        mel = self._compute_mel(x)\n\n        # LTX2Vocoder.forward expects (B, C, T, mel_bins) — transpose before calling bwe_generator\n        mel_for_bwe = mel.transpose(2, 3)  # (B, C, T_frames, mel_bins)\n        residual = self.bwe_generator(mel_for_bwe)\n        skip = self.resampler(x)\n        assert residual.shape == skip.shape, f\"residual {residual.shape} != skip {skip.shape}\"\n\n        return torch.clamp(residual + skip, -1, 1)[..., :output_length]\n"
  },
  {
    "path": "diffsynth/models/ltx2_common.py",
    "content": "from dataclasses import dataclass\nfrom typing import NamedTuple, Protocol, Tuple\nimport torch\nfrom torch import nn\nfrom enum import Enum\n\n\nclass VideoPixelShape(NamedTuple):\n    \"\"\"\n    Shape of the tensor representing the video pixel array. Assumes BGR channel format.\n    \"\"\"\n\n    batch: int\n    frames: int\n    height: int\n    width: int\n    fps: float\n\n\nclass SpatioTemporalScaleFactors(NamedTuple):\n    \"\"\"\n    Describes the spatiotemporal downscaling between decoded video space and\n    the corresponding VAE latent grid.\n    \"\"\"\n\n    time: int\n    width: int\n    height: int\n\n    @classmethod\n    def default(cls) -> \"SpatioTemporalScaleFactors\":\n        return cls(time=8, width=32, height=32)\n\n\nVIDEO_SCALE_FACTORS = SpatioTemporalScaleFactors.default()\n\n\nclass VideoLatentShape(NamedTuple):\n    \"\"\"\n    Shape of the tensor representing video in VAE latent space.\n    The latent representation is a 5D tensor with dimensions ordered as\n    (batch, channels, frames, height, width). 
Spatial and temporal dimensions\n    are downscaled relative to pixel space according to the VAE's scale factors.\n    \"\"\"\n\n    batch: int\n    channels: int\n    frames: int\n    height: int\n    width: int\n\n    def to_torch_shape(self) -> torch.Size:\n        return torch.Size([self.batch, self.channels, self.frames, self.height, self.width])\n\n    @staticmethod\n    def from_torch_shape(shape: torch.Size) -> \"VideoLatentShape\":\n        return VideoLatentShape(\n            batch=shape[0],\n            channels=shape[1],\n            frames=shape[2],\n            height=shape[3],\n            width=shape[4],\n        )\n\n    def mask_shape(self) -> \"VideoLatentShape\":\n        return self._replace(channels=1)\n\n    @staticmethod\n    def from_pixel_shape(\n        shape: VideoPixelShape,\n        latent_channels: int = 128,\n        scale_factors: SpatioTemporalScaleFactors = VIDEO_SCALE_FACTORS,\n    ) -> \"VideoLatentShape\":\n        frames = (shape.frames - 1) // scale_factors[0] + 1\n        height = shape.height // scale_factors[1]\n        width = shape.width // scale_factors[2]\n\n        return VideoLatentShape(\n            batch=shape.batch,\n            channels=latent_channels,\n            frames=frames,\n            height=height,\n            width=width,\n        )\n\n    def upscale(self, scale_factors: SpatioTemporalScaleFactors = VIDEO_SCALE_FACTORS) -> \"VideoLatentShape\":\n        return self._replace(\n            channels=3,\n            frames=(self.frames - 1) * scale_factors.time + 1,\n            height=self.height * scale_factors.height,\n            width=self.width * scale_factors.width,\n        )\n\n\nclass AudioLatentShape(NamedTuple):\n    \"\"\"\n    Shape of audio in VAE latent space: (batch, channels, frames, mel_bins).\n    mel_bins is the number of frequency bins from the mel-spectrogram encoding.\n    \"\"\"\n\n    batch: int\n    channels: int\n    frames: int\n    mel_bins: int\n\n    def 
to_torch_shape(self) -> torch.Size:\n        return torch.Size([self.batch, self.channels, self.frames, self.mel_bins])\n\n    def mask_shape(self) -> \"AudioLatentShape\":\n        return self._replace(channels=1, mel_bins=1)\n\n    @staticmethod\n    def from_torch_shape(shape: torch.Size) -> \"AudioLatentShape\":\n        return AudioLatentShape(\n            batch=shape[0],\n            channels=shape[1],\n            frames=shape[2],\n            mel_bins=shape[3],\n        )\n\n    @staticmethod\n    def from_duration(\n        batch: int,\n        duration: float,\n        channels: int = 8,\n        mel_bins: int = 16,\n        sample_rate: int = 16000,\n        hop_length: int = 160,\n        audio_latent_downsample_factor: int = 4,\n    ) -> \"AudioLatentShape\":\n        latents_per_second = float(sample_rate) / float(hop_length) / float(audio_latent_downsample_factor)\n\n        return AudioLatentShape(\n            batch=batch,\n            channels=channels,\n            frames=round(duration * latents_per_second),\n            mel_bins=mel_bins,\n        )\n\n    @staticmethod\n    def from_video_pixel_shape(\n        shape: VideoPixelShape,\n        channels: int = 8,\n        mel_bins: int = 16,\n        sample_rate: int = 16000,\n        hop_length: int = 160,\n        audio_latent_downsample_factor: int = 4,\n    ) -> \"AudioLatentShape\":\n        return AudioLatentShape.from_duration(\n            batch=shape.batch,\n            duration=float(shape.frames) / float(shape.fps),\n            channels=channels,\n            mel_bins=mel_bins,\n            sample_rate=sample_rate,\n            hop_length=hop_length,\n            audio_latent_downsample_factor=audio_latent_downsample_factor,\n        )\n\n\n@dataclass(frozen=True)\nclass LatentState:\n    \"\"\"\n    State of latents during the diffusion denoising process.\n    Attributes:\n        latent: The current noisy latent tensor being denoised.\n        denoise_mask: Mask encoding the 
denoising strength for each token (1 = full denoising, 0 = no denoising).\n        positions: Positional indices for each latent element, used for positional embeddings.\n        clean_latent: Initial state of the latent before denoising, may include conditioning latents.\n    \"\"\"\n\n    latent: torch.Tensor\n    denoise_mask: torch.Tensor\n    positions: torch.Tensor\n    clean_latent: torch.Tensor\n\n    def clone(self) -> \"LatentState\":\n        return LatentState(\n            latent=self.latent.clone(),\n            denoise_mask=self.denoise_mask.clone(),\n            positions=self.positions.clone(),\n            clean_latent=self.clean_latent.clone(),\n        )\n\n\nclass NormType(Enum):\n    \"\"\"Normalization layer types: GROUP (GroupNorm) or PIXEL (per-location RMS norm).\"\"\"\n\n    GROUP = \"group\"\n    PIXEL = \"pixel\"\n\n\nclass PixelNorm(nn.Module):\n    \"\"\"\n    Per-pixel (per-location) RMS normalization layer.\n    For each element along the chosen dimension, this layer normalizes the tensor\n    by the root-mean-square of its values across that dimension:\n        y = x / sqrt(mean(x^2, dim=dim, keepdim=True) + eps)\n    \"\"\"\n\n    def __init__(self, dim: int = 1, eps: float = 1e-8) -> None:\n        \"\"\"\n        Args:\n            dim: Dimension along which to compute the RMS (typically channels).\n            eps: Small constant added for numerical stability.\n        \"\"\"\n        super().__init__()\n        self.dim = dim\n        self.eps = eps\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Apply RMS normalization along the configured dimension.\n        \"\"\"\n        # Compute mean of squared values along `dim`, keep dimensions for broadcasting.\n        mean_sq = torch.mean(x**2, dim=self.dim, keepdim=True)\n        # Normalize by the root-mean-square (RMS).\n        rms = torch.sqrt(mean_sq + self.eps)\n        return x / rms\n\n\ndef build_normalization_layer(\n    in_channels: 
int, *, num_groups: int = 32, normtype: NormType = NormType.GROUP\n) -> nn.Module:\n    \"\"\"\n    Create a normalization layer based on the normalization type.\n    Args:\n        in_channels: Number of input channels\n        num_groups: Number of groups for group normalization\n        normtype: Type of normalization: \"group\" or \"pixel\"\n    Returns:\n        A normalization layer\n    \"\"\"\n    if normtype == NormType.GROUP:\n        return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)\n    if normtype == NormType.PIXEL:\n        return PixelNorm(dim=1, eps=1e-6)\n    raise ValueError(f\"Invalid normalization type: {normtype}\")\n\n\ndef rms_norm(x: torch.Tensor, weight: torch.Tensor | None = None, eps: float = 1e-6) -> torch.Tensor:\n    \"\"\"Root-mean-square (RMS) normalize `x` over its last dimension.\n    Thin wrapper around `torch.nn.functional.rms_norm` that infers the normalized\n    shape and forwards `weight` and `eps`.\n    \"\"\"\n    return torch.nn.functional.rms_norm(x, (x.shape[-1],), weight=weight, eps=eps)\n\n\n@dataclass(frozen=True)\nclass Modality:\n    \"\"\"\n    Input data for a single modality (video or audio) in the transformer.\n    Bundles the latent tokens, timestep embeddings, positional information,\n    and text conditioning context for processing by the diffusion transformer.\n    Attributes:\n        latent: Patchified latent tokens, shape ``(B, T, D)`` where *B* is\n            the batch size, *T* is the total number of tokens (noisy +\n            conditioning), and *D* is the input dimension.\n        timesteps: Per-token timestep embeddings, shape ``(B, T)``.\n        positions: Positional coordinates, shape ``(B, 3, T)`` for video\n            (time, height, width) or ``(B, 1, T)`` for audio.\n        context: Text conditioning embeddings from the prompt encoder.\n        enabled: Whether this modality is active in the current forward pass.\n        context_mask: Optional 
mask for the text context tokens.\n        attention_mask: Optional 2-D self-attention mask, shape ``(B, T, T)``.\n            Values in ``[0, 1]`` where ``1`` = full attention and ``0`` = no\n            attention. ``None`` means unrestricted (full) attention between\n            all tokens. Built incrementally by conditioning items; see\n            :class:`~ltx_core.conditioning.types.attention_strength_wrapper.ConditioningItemAttentionStrengthWrapper`.\n    \"\"\"\n\n    latent: (\n        torch.Tensor\n    )  # Shape: (B, T, D) where B is the batch size, T is the number of tokens, and D is input dimension\n    sigma: torch.Tensor  # Shape: (B,). Current sigma value, used for cross-attention timestep calculation.\n    timesteps: torch.Tensor  # Shape: (B, T) where T is the number of timesteps\n    positions: (\n        torch.Tensor\n    )  # Shape: (B, 3, T) for video, where 3 is the number of dimensions and T is the number of tokens\n    context: torch.Tensor\n    enabled: bool = True\n    context_mask: torch.Tensor | None = None\n    attention_mask: torch.Tensor | None = None\n\n\ndef to_denoised(\n    sample: torch.Tensor,\n    velocity: torch.Tensor,\n    sigma: float | torch.Tensor,\n    calc_dtype: torch.dtype = torch.float32,\n) -> torch.Tensor:\n    \"\"\"\n    Convert the sample and its denoising velocity to denoised sample.\n    Returns:\n        Denoised sample\n    \"\"\"\n    if isinstance(sigma, torch.Tensor):\n        sigma = sigma.to(calc_dtype)\n    return (sample.to(calc_dtype) - velocity.to(calc_dtype) * sigma).to(sample.dtype)\n\n\n\nclass Patchifier(Protocol):\n    \"\"\"\n    Protocol for patchifiers that convert latent tensors into patches and assemble them back.\n    \"\"\"\n\n    def patchify(\n        self,\n        latents: torch.Tensor,\n    ) -> torch.Tensor:\n        ...\n        \"\"\"\n        Convert latent tensors into flattened patch tokens.\n        Args:\n            latents: Latent tensor to patchify.\n        Returns:\n    
        Flattened patch tokens tensor.\n        \"\"\"\n\n    def unpatchify(\n        self,\n        latents: torch.Tensor,\n        output_shape: AudioLatentShape | VideoLatentShape,\n    ) -> torch.Tensor:\n        \"\"\"\n        Converts latent tensors between spatio-temporal formats and flattened sequence representations.\n        Args:\n            latents: Patch tokens that must be rearranged back into the latent grid constructed by `patchify`.\n            output_shape: Shape of the output tensor. Note that output_shape is either AudioLatentShape or\n            VideoLatentShape.\n        Returns:\n            Dense latent tensor restored from the flattened representation.\n        \"\"\"\n\n    @property\n    def patch_size(self) -> Tuple[int, int, int]:\n        ...\n        \"\"\"\n        Returns the patch size as a tuple of (temporal, height, width) dimensions\n        \"\"\"\n\n    def get_patch_grid_bounds(\n        self,\n        output_shape: AudioLatentShape | VideoLatentShape,\n        device: torch.device | None = None,\n    ) -> torch.Tensor:\n        ...\n        \"\"\"\n        Compute metadata describing where each latent patch resides within the\n        grid specified by `output_shape`.\n        Args:\n            output_shape: Target grid layout for the patches.\n            device: Target device for the returned tensor.\n        Returns:\n            Tensor containing patch coordinate metadata such as spatial or temporal intervals.\n        \"\"\"\n\n\ndef get_pixel_coords(\n    latent_coords: torch.Tensor,\n    scale_factors: SpatioTemporalScaleFactors,\n    causal_fix: bool = False,\n) -> torch.Tensor:\n    \"\"\"\n    Map latent-space `[start, end)` coordinates to their pixel-space equivalents by scaling\n    each axis (frame/time, height, width) with the corresponding VAE downsampling factors.\n    Optionally compensate for causal encoding that keeps the first frame at unit temporal scale.\n    Args:\n        latent_coords: Tensor 
of latent bounds shaped `(batch, 3, num_patches, 2)`.\n        scale_factors: SpatioTemporalScaleFactors tuple `(temporal, height, width)` with integer scale factors applied\n        per axis.\n        causal_fix: When True, rewrites the temporal axis of the first frame so causal VAEs\n            that treat frame zero differently still yield non-negative timestamps.\n    \"\"\"\n    # Broadcast the VAE scale factors so they align with the `(batch, axis, patch, bound)` layout.\n    broadcast_shape = [1] * latent_coords.ndim\n    broadcast_shape[1] = -1  # axis dimension corresponds to (frame/time, height, width)\n    scale_tensor = torch.tensor(scale_factors, device=latent_coords.device).view(*broadcast_shape)\n\n    # Apply per-axis scaling to convert latent bounds into pixel-space coordinates.\n    pixel_coords = latent_coords * scale_tensor\n\n    if causal_fix:\n        # VAE temporal stride for the very first frame is 1 instead of `scale_factors[0]`.\n        # Shift and clamp to keep the first-frame timestamps causal and non-negative.\n        pixel_coords[:, 0, ...] = (pixel_coords[:, 0, ...] + 1 - scale_factors[0]).clamp(min=0)\n\n    return pixel_coords\n"
  },
  {
    "path": "diffsynth/models/ltx2_dit.py",
    "content": "import math\nimport functools\nfrom dataclasses import dataclass, replace\nfrom enum import Enum\nfrom typing import Optional, Tuple, Callable\nimport numpy as np\nimport torch\nfrom einops import rearrange\nfrom .ltx2_common import rms_norm, Modality\nfrom ..core.attention.attention import attention_forward\nfrom ..core import gradient_checkpoint_forward\n\n\ndef get_timestep_embedding(\n    timesteps: torch.Tensor,\n    embedding_dim: int,\n    flip_sin_to_cos: bool = False,\n    downscale_freq_shift: float = 1,\n    scale: float = 1,\n    max_period: int = 10000,\n) -> torch.Tensor:\n    \"\"\"\n    This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.\n    Args\n        timesteps (torch.Tensor):\n            a 1-D Tensor of N indices, one per batch element. These may be fractional.\n        embedding_dim (int):\n            the dimension of the output.\n        flip_sin_to_cos (bool):\n            Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False)\n        downscale_freq_shift (float):\n            Controls the delta between frequencies between dimensions\n        scale (float):\n            Scaling factor applied to the embeddings.\n        max_period (int):\n            Controls the maximum frequency of the embeddings\n    Returns\n        torch.Tensor: an [N x dim] Tensor of positional embeddings.\n    \"\"\"\n    assert len(timesteps.shape) == 1, \"Timesteps should be a 1d-array\"\n\n    half_dim = embedding_dim // 2\n    exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)\n    exponent = exponent / (half_dim - downscale_freq_shift)\n\n    emb = torch.exp(exponent)\n    emb = timesteps[:, None].float() * emb[None, :]\n\n    # scale embeddings\n    emb = scale * emb\n\n    # concat sine and cosine embeddings\n    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)\n\n    # flip 
sine and cosine embeddings\n    if flip_sin_to_cos:\n        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)\n\n    # zero pad\n    if embedding_dim % 2 == 1:\n        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))\n    return emb\n\n\nclass TimestepEmbedding(torch.nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        time_embed_dim: int,\n        out_dim: int | None = None,\n        post_act_fn: str | None = None,\n        cond_proj_dim: int | None = None,\n        sample_proj_bias: bool = True,\n    ):\n        super().__init__()\n\n        self.linear_1 = torch.nn.Linear(in_channels, time_embed_dim, sample_proj_bias)\n\n        if cond_proj_dim is not None:\n            self.cond_proj = torch.nn.Linear(cond_proj_dim, in_channels, bias=False)\n        else:\n            self.cond_proj = None\n\n        self.act = torch.nn.SiLU()\n        time_embed_dim_out = out_dim if out_dim is not None else time_embed_dim\n\n        self.linear_2 = torch.nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)\n\n        if post_act_fn is None:\n            self.post_act = None\n\n    def forward(self, sample: torch.Tensor, condition: torch.Tensor | None = None) -> torch.Tensor:\n        if condition is not None:\n            sample = sample + self.cond_proj(condition)\n        sample = self.linear_1(sample)\n\n        if self.act is not None:\n            sample = self.act(sample)\n\n        sample = self.linear_2(sample)\n\n        if self.post_act is not None:\n            sample = self.post_act(sample)\n        return sample\n\n\nclass Timesteps(torch.nn.Module):\n    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):\n        super().__init__()\n        self.num_channels = num_channels\n        self.flip_sin_to_cos = flip_sin_to_cos\n        self.downscale_freq_shift = downscale_freq_shift\n        self.scale = scale\n\n    def forward(self, timesteps: 
torch.Tensor) -> torch.Tensor:\n        t_emb = get_timestep_embedding(\n            timesteps,\n            self.num_channels,\n            flip_sin_to_cos=self.flip_sin_to_cos,\n            downscale_freq_shift=self.downscale_freq_shift,\n            scale=self.scale,\n        )\n        return t_emb\n\n\nclass PixArtAlphaCombinedTimestepSizeEmbeddings(torch.nn.Module):\n    \"\"\"\n    For PixArt-Alpha.\n    Reference:\n    https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29\n    \"\"\"\n\n    def __init__(\n        self,\n        embedding_dim: int,\n        size_emb_dim: int,\n    ):\n        super().__init__()\n\n        self.outdim = size_emb_dim\n        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)\n        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)\n\n    def forward(\n        self,\n        timestep: torch.Tensor,\n        hidden_dtype: torch.dtype,\n    ) -> torch.Tensor:\n        timesteps_proj = self.time_proj(timestep)\n        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype))  # (N, D)\n        return timesteps_emb\n\n\nclass PerturbationType(Enum):\n    \"\"\"Types of attention perturbations for STG (Spatio-Temporal Guidance).\"\"\"\n\n    SKIP_A2V_CROSS_ATTN = \"skip_a2v_cross_attn\"\n    SKIP_V2A_CROSS_ATTN = \"skip_v2a_cross_attn\"\n    SKIP_VIDEO_SELF_ATTN = \"skip_video_self_attn\"\n    SKIP_AUDIO_SELF_ATTN = \"skip_audio_self_attn\"\n\n\n@dataclass(frozen=True)\nclass Perturbation:\n    \"\"\"A single perturbation specifying which attention type to skip and in which blocks.\"\"\"\n\n    type: PerturbationType\n    blocks: list[int] | None  # None means all blocks\n\n    def is_perturbed(self, perturbation_type: PerturbationType, block: int) -> bool:\n        if self.type != perturbation_type:\n            return False\n\n        if 
self.blocks is None:\n            return True\n\n        return block in self.blocks\n\n\n@dataclass(frozen=True)\nclass PerturbationConfig:\n    \"\"\"Configuration holding a list of perturbations for a single sample.\"\"\"\n\n    perturbations: list[Perturbation] | None\n\n    def is_perturbed(self, perturbation_type: PerturbationType, block: int) -> bool:\n        if self.perturbations is None:\n            return False\n\n        return any(perturbation.is_perturbed(perturbation_type, block) for perturbation in self.perturbations)\n\n    @staticmethod\n    def empty() -> \"PerturbationConfig\":\n        return PerturbationConfig([])\n\n\n@dataclass(frozen=True)\nclass BatchedPerturbationConfig:\n    \"\"\"Perturbation configurations for a batch, with utilities for generating attention masks.\"\"\"\n\n    perturbations: list[PerturbationConfig]\n\n    def mask(\n        self, perturbation_type: PerturbationType, block: int, device, dtype: torch.dtype\n    ) -> torch.Tensor:\n        mask = torch.ones((len(self.perturbations),), device=device, dtype=dtype)\n        for batch_idx, perturbation in enumerate(self.perturbations):\n            if perturbation.is_perturbed(perturbation_type, block):\n                mask[batch_idx] = 0\n\n        return mask\n\n    def mask_like(self, perturbation_type: PerturbationType, block: int, values: torch.Tensor) -> torch.Tensor:\n        mask = self.mask(perturbation_type, block, values.device, values.dtype)\n        return mask.view(mask.numel(), *([1] * len(values.shape[1:])))\n\n    def any_in_batch(self, perturbation_type: PerturbationType, block: int) -> bool:\n        return any(perturbation.is_perturbed(perturbation_type, block) for perturbation in self.perturbations)\n\n    def all_in_batch(self, perturbation_type: PerturbationType, block: int) -> bool:\n        return all(perturbation.is_perturbed(perturbation_type, block) for perturbation in self.perturbations)\n\n    @staticmethod\n    def empty(batch_size: int) -> 
\"BatchedPerturbationConfig\":\n        return BatchedPerturbationConfig([PerturbationConfig.empty() for _ in range(batch_size)])\n\n\n\nADALN_NUM_BASE_PARAMS = 6\n# Cross-attention AdaLN adds 3 more (scale, shift, gate) for the CA norm.\nADALN_NUM_CROSS_ATTN_PARAMS = 3\n\n\ndef adaln_embedding_coefficient(cross_attention_adaln: bool) -> int:\n    \"\"\"Total number of AdaLN parameters per block.\"\"\"\n    return ADALN_NUM_BASE_PARAMS + (ADALN_NUM_CROSS_ATTN_PARAMS if cross_attention_adaln else 0)\n\n\nclass AdaLayerNormSingle(torch.nn.Module):\n    r\"\"\"\n    Norm layer adaptive layer norm single (adaLN-single).\n    As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).\n    Parameters:\n        embedding_dim (`int`): The size of each embedding vector.\n        use_additional_conditions (`bool`): To use additional conditions for normalization or not.\n    \"\"\"\n\n    def __init__(self, embedding_dim: int, embedding_coefficient: int = 6):\n        super().__init__()\n\n        self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(\n            embedding_dim,\n            size_emb_dim=embedding_dim // 3,\n        )\n\n        self.silu = torch.nn.SiLU()\n        self.linear = torch.nn.Linear(embedding_dim, embedding_coefficient * embedding_dim, bias=True)\n\n    def forward(\n        self,\n        timestep: torch.Tensor,\n        hidden_dtype: Optional[torch.dtype] = None,\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        embedded_timestep = self.emb(timestep, hidden_dtype=hidden_dtype)\n        return self.linear(self.silu(embedded_timestep)), embedded_timestep\n\n\nclass LTXRopeType(Enum):\n    INTERLEAVED = \"interleaved\"\n    SPLIT = \"split\"\n\n\ndef apply_rotary_emb(\n    input_tensor: torch.Tensor,\n    freqs_cis: Tuple[torch.Tensor, torch.Tensor],\n    rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,\n) -> torch.Tensor:\n    if rope_type == LTXRopeType.INTERLEAVED:\n        return 
apply_interleaved_rotary_emb(input_tensor, *freqs_cis)\n    elif rope_type == LTXRopeType.SPLIT:\n        return apply_split_rotary_emb(input_tensor, *freqs_cis)\n    else:\n        raise ValueError(f\"Invalid rope type: {rope_type}\")\n\n\n\ndef apply_interleaved_rotary_emb(\n    input_tensor: torch.Tensor, cos_freqs: torch.Tensor, sin_freqs: torch.Tensor\n) -> torch.Tensor:\n    t_dup = rearrange(input_tensor, \"... (d r) -> ... d r\", r=2)\n    t1, t2 = t_dup.unbind(dim=-1)\n    t_dup = torch.stack((-t2, t1), dim=-1)\n    input_tensor_rot = rearrange(t_dup, \"... d r -> ... (d r)\")\n\n    out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs\n\n    return out\n\n\ndef apply_split_rotary_emb(\n    input_tensor: torch.Tensor, cos_freqs: torch.Tensor, sin_freqs: torch.Tensor\n) -> torch.Tensor:\n    needs_reshape = False\n    if input_tensor.ndim != 4 and cos_freqs.ndim == 4:\n        b, h, t, _ = cos_freqs.shape\n        input_tensor = input_tensor.reshape(b, t, h, -1).swapaxes(1, 2)\n        needs_reshape = True\n\n    split_input = rearrange(input_tensor, \"... (d r) -> ... d r\", d=2)\n    first_half_input = split_input[..., :1, :]\n    second_half_input = split_input[..., 1:, :]\n\n    output = split_input * cos_freqs.unsqueeze(-2)\n    first_half_output = output[..., :1, :]\n    second_half_output = output[..., 1:, :]\n\n    first_half_output.addcmul_(-sin_freqs.unsqueeze(-2), second_half_input)\n    second_half_output.addcmul_(sin_freqs.unsqueeze(-2), first_half_input)\n\n    output = rearrange(output, \"... d r -> ... 
(d r)\")\n    if needs_reshape:\n        output = output.swapaxes(1, 2).reshape(b, t, -1)\n\n    return output\n\n\n@functools.lru_cache(maxsize=5)\ndef generate_freq_grid_np(\n    positional_embedding_theta: float, positional_embedding_max_pos_count: int, inner_dim: int\n) -> torch.Tensor:\n    theta = positional_embedding_theta\n    start = 1\n    end = theta\n\n    n_elem = 2 * positional_embedding_max_pos_count\n    pow_indices = np.power(\n        theta,\n        np.linspace(\n            np.log(start) / np.log(theta),\n            np.log(end) / np.log(theta),\n            inner_dim // n_elem,\n            dtype=np.float64,\n        ),\n    )\n    return torch.tensor(pow_indices * math.pi / 2, dtype=torch.float32)\n\n\n@functools.lru_cache(maxsize=5)\ndef generate_freq_grid_pytorch(\n    positional_embedding_theta: float, positional_embedding_max_pos_count: int, inner_dim: int\n) -> torch.Tensor:\n    theta = positional_embedding_theta\n    start = 1\n    end = theta\n    n_elem = 2 * positional_embedding_max_pos_count\n\n    indices = theta ** (\n        torch.linspace(\n            math.log(start, theta),\n            math.log(end, theta),\n            inner_dim // n_elem,\n            dtype=torch.float32,\n        )\n    )\n    indices = indices.to(dtype=torch.float32)\n\n    indices = indices * math.pi / 2\n\n    return indices\n\n\ndef get_fractional_positions(indices_grid: torch.Tensor, max_pos: list[int]) -> torch.Tensor:\n    n_pos_dims = indices_grid.shape[1]\n    assert n_pos_dims == len(max_pos), (\n        f\"Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})\"\n    )\n    fractional_positions = torch.stack(\n        [indices_grid[:, i] / max_pos[i] for i in range(n_pos_dims)],\n        dim=-1,\n    )\n    return fractional_positions\n\n\ndef generate_freqs(\n    indices: torch.Tensor, indices_grid: torch.Tensor, max_pos: list[int], use_middle_indices_grid: bool\n) -> torch.Tensor:\n    if 
use_middle_indices_grid:\n        assert len(indices_grid.shape) == 4\n        assert indices_grid.shape[-1] == 2\n        indices_grid_start, indices_grid_end = indices_grid[..., 0], indices_grid[..., 1]\n        indices_grid = (indices_grid_start + indices_grid_end) / 2.0\n    elif len(indices_grid.shape) == 4:\n        indices_grid = indices_grid[..., 0]\n\n    fractional_positions = get_fractional_positions(indices_grid, max_pos)\n    indices = indices.to(device=fractional_positions.device)\n\n    freqs = (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)).transpose(-1, -2).flatten(2)\n    return freqs\n\n\ndef split_freqs_cis(freqs: torch.Tensor, pad_size: int, num_attention_heads: int) -> tuple[torch.Tensor, torch.Tensor]:\n    cos_freq = freqs.cos()\n    sin_freq = freqs.sin()\n\n    if pad_size != 0:\n        cos_padding = torch.ones_like(cos_freq[:, :, :pad_size])\n        sin_padding = torch.zeros_like(sin_freq[:, :, :pad_size])\n\n        cos_freq = torch.concatenate([cos_padding, cos_freq], axis=-1)\n        sin_freq = torch.concatenate([sin_padding, sin_freq], axis=-1)\n\n    # Reshape freqs to be compatible with multi-head attention\n    b = cos_freq.shape[0]\n    t = cos_freq.shape[1]\n\n    cos_freq = cos_freq.reshape(b, t, num_attention_heads, -1)\n    sin_freq = sin_freq.reshape(b, t, num_attention_heads, -1)\n\n    cos_freq = torch.swapaxes(cos_freq, 1, 2)  # (B,H,T,D//2)\n    sin_freq = torch.swapaxes(sin_freq, 1, 2)  # (B,H,T,D//2)\n    return cos_freq, sin_freq\n\n\ndef interleaved_freqs_cis(freqs: torch.Tensor, pad_size: int) -> tuple[torch.Tensor, torch.Tensor]:\n    cos_freq = freqs.cos().repeat_interleave(2, dim=-1)\n    sin_freq = freqs.sin().repeat_interleave(2, dim=-1)\n    if pad_size != 0:\n        cos_padding = torch.ones_like(cos_freq[:, :, :pad_size])\n        sin_padding = torch.zeros_like(cos_freq[:, :, :pad_size])\n        cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)\n        sin_freq = torch.cat([sin_padding, 
sin_freq], dim=-1)\n    return cos_freq, sin_freq\n\n\ndef precompute_freqs_cis(\n    indices_grid: torch.Tensor,\n    dim: int,\n    out_dtype: torch.dtype,\n    theta: float = 10000.0,\n    max_pos: list[int] | None = None,\n    use_middle_indices_grid: bool = False,\n    num_attention_heads: int = 32,\n    rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,\n    freq_grid_generator: Callable[[float, int, int, torch.device], torch.Tensor] = generate_freq_grid_pytorch,\n) -> tuple[torch.Tensor, torch.Tensor]:\n    if max_pos is None:\n        max_pos = [20, 2048, 2048]\n\n    indices = freq_grid_generator(theta, indices_grid.shape[1], dim)\n    freqs = generate_freqs(indices, indices_grid, max_pos, use_middle_indices_grid)\n\n    if rope_type == LTXRopeType.SPLIT:\n        expected_freqs = dim // 2\n        current_freqs = freqs.shape[-1]\n        pad_size = expected_freqs - current_freqs\n        cos_freq, sin_freq = split_freqs_cis(freqs, pad_size, num_attention_heads)\n    else:\n        # 2 because of cos and sin by 3 for (t, x, y), 1 for temporal only\n        n_elem = 2 * indices_grid.shape[1]\n        cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem)\n    return cos_freq.to(out_dtype), sin_freq.to(out_dtype)\n\n\nclass Attention(torch.nn.Module):\n    def __init__(\n        self,\n        query_dim: int,\n        context_dim: int | None = None,\n        heads: int = 8,\n        dim_head: int = 64,\n        norm_eps: float = 1e-6,\n        rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,\n        apply_gated_attention: bool = False,\n    ) -> None:\n        super().__init__()\n        self.rope_type = rope_type\n\n        inner_dim = dim_head * heads\n        context_dim = query_dim if context_dim is None else context_dim\n\n        self.heads = heads\n        self.dim_head = dim_head\n\n        self.q_norm = torch.nn.RMSNorm(inner_dim, eps=norm_eps)\n        self.k_norm = torch.nn.RMSNorm(inner_dim, eps=norm_eps)\n\n        self.to_q = 
torch.nn.Linear(query_dim, inner_dim, bias=True)\n        self.to_k = torch.nn.Linear(context_dim, inner_dim, bias=True)\n        self.to_v = torch.nn.Linear(context_dim, inner_dim, bias=True)\n\n        # Optional per-head gating\n        if apply_gated_attention:\n            self.to_gate_logits = torch.nn.Linear(query_dim, heads, bias=True)\n        else:\n            self.to_gate_logits = None\n\n        self.to_out = torch.nn.Sequential(torch.nn.Linear(inner_dim, query_dim, bias=True), torch.nn.Identity())\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        context: torch.Tensor | None = None,\n        mask: torch.Tensor | None = None,\n        pe: torch.Tensor | None = None,\n        k_pe: torch.Tensor | None = None,\n        perturbation_mask: torch.Tensor | None = None,\n        all_perturbed: bool = False,\n    ) -> torch.Tensor:\n        q = self.to_q(x)\n        context = x if context is None else context\n        k = self.to_k(context)\n        v = self.to_v(context)\n\n        q = self.q_norm(q)\n        k = self.k_norm(k)\n\n        if pe is not None:\n            q = apply_rotary_emb(q, pe, self.rope_type)\n            k = apply_rotary_emb(k, pe if k_pe is None else k_pe, self.rope_type)\n\n        # Reshape for attention_forward using unflatten\n        q = q.unflatten(-1, (self.heads, self.dim_head))\n        k = k.unflatten(-1, (self.heads, self.dim_head))\n        v = v.unflatten(-1, (self.heads, self.dim_head))\n\n        out = attention_forward(\n            q=q,\n            k=k,\n            v=v,\n            q_pattern=\"b s n d\",\n            k_pattern=\"b s n d\",\n            v_pattern=\"b s n d\",\n            out_pattern=\"b s n d\",\n            attn_mask=mask\n        )\n\n        # Reshape back to original format\n        out = out.flatten(2, 3)\n\n        # Apply per-head gating if enabled\n        if self.to_gate_logits is not None:\n            gate_logits = self.to_gate_logits(x)  # (B, T, H)\n            b, t, 
_ = out.shape\n            # Reshape to (B, T, H, D) for per-head gating\n            out = out.view(b, t, self.heads, self.dim_head)\n            # Apply gating: 2 * sigmoid(x) so that zero-init gives identity (2 * 0.5 = 1.0)\n            gates = 2.0 * torch.sigmoid(gate_logits)  # (B, T, H)\n            out = out * gates.unsqueeze(-1)  # (B, T, H, D) * (B, T, H, 1)\n            # Reshape back to (B, T, H*D)\n            out = out.view(b, t, self.heads * self.dim_head)\n\n        return self.to_out(out)\n\n\nclass PixArtAlphaTextProjection(torch.nn.Module):\n    \"\"\"\n    Projects caption embeddings. Also handles dropout for classifier-free guidance.\n    Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py\n    \"\"\"\n\n    def __init__(self, in_features: int, hidden_size: int, out_features: int | None = None, act_fn: str = \"gelu_tanh\"):\n        super().__init__()\n        if out_features is None:\n            out_features = hidden_size\n        self.linear_1 = torch.nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)\n        if act_fn == \"gelu_tanh\":\n            self.act_1 = torch.nn.GELU(approximate=\"tanh\")\n        elif act_fn == \"silu\":\n            self.act_1 = torch.nn.SiLU()\n        else:\n            raise ValueError(f\"Unknown activation function: {act_fn}\")\n        self.linear_2 = torch.nn.Linear(in_features=hidden_size, out_features=out_features, bias=True)\n\n    def forward(self, caption: torch.Tensor) -> torch.Tensor:\n        hidden_states = self.linear_1(caption)\n        hidden_states = self.act_1(hidden_states)\n        hidden_states = self.linear_2(hidden_states)\n        return hidden_states\n\n@dataclass(frozen=True)\nclass TransformerArgs:\n    x: torch.Tensor\n    context: torch.Tensor\n    context_mask: torch.Tensor\n    timesteps: torch.Tensor\n    embedded_timestep: torch.Tensor\n    positional_embeddings: torch.Tensor\n    
cross_positional_embeddings: torch.Tensor | None\n    cross_scale_shift_timestep: torch.Tensor | None\n    cross_gate_timestep: torch.Tensor | None\n    enabled: bool\n    prompt_timestep: torch.Tensor | None = None\n    self_attention_mask: torch.Tensor | None = (\n        None  # Additive log-space self-attention bias (B, 1, T, T), None = full attention\n    )\n\n\nclass TransformerArgsPreprocessor:\n    def __init__(  # noqa: PLR0913\n        self,\n        patchify_proj: torch.nn.Linear,\n        adaln: AdaLayerNormSingle,\n        inner_dim: int,\n        max_pos: list[int],\n        num_attention_heads: int,\n        use_middle_indices_grid: bool,\n        timestep_scale_multiplier: int,\n        double_precision_rope: bool,\n        positional_embedding_theta: float,\n        rope_type: LTXRopeType,\n        caption_projection: torch.nn.Module | None = None,\n        prompt_adaln: AdaLayerNormSingle | None = None,\n    ) -> None:\n        self.patchify_proj = patchify_proj\n        self.adaln = adaln\n        self.inner_dim = inner_dim\n        self.max_pos = max_pos\n        self.num_attention_heads = num_attention_heads\n        self.use_middle_indices_grid = use_middle_indices_grid\n        self.timestep_scale_multiplier = timestep_scale_multiplier\n        self.double_precision_rope = double_precision_rope\n        self.positional_embedding_theta = positional_embedding_theta\n        self.rope_type = rope_type\n        self.caption_projection = caption_projection\n        self.prompt_adaln = prompt_adaln\n\n    def _prepare_timestep(\n        self, timestep: torch.Tensor, adaln: AdaLayerNormSingle, batch_size: int, hidden_dtype: torch.dtype\n    ) -> tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"Prepare timestep embeddings.\"\"\"\n        timestep_scaled = timestep * self.timestep_scale_multiplier\n        timestep, embedded_timestep = adaln(\n            timestep_scaled.flatten(),\n            hidden_dtype=hidden_dtype,\n        )\n        # Second 
dimension is 1 or number of tokens (if timestep_per_token)\n        timestep = timestep.view(batch_size, -1, timestep.shape[-1])\n        embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.shape[-1])\n        return timestep, embedded_timestep\n\n    def _prepare_context(\n        self,\n        context: torch.Tensor,\n        x: torch.Tensor,\n    ) -> torch.Tensor:\n        \"\"\"Prepare context for transformer blocks.\"\"\"\n        if self.caption_projection is not None:\n            context = self.caption_projection(context)\n        batch_size = x.shape[0]\n        return context.view(batch_size, -1, x.shape[-1])\n\n    def _prepare_attention_mask(self, attention_mask: torch.Tensor | None, x_dtype: torch.dtype) -> torch.Tensor | None:\n        \"\"\"Prepare attention mask.\"\"\"\n        if attention_mask is None or torch.is_floating_point(attention_mask):\n            return attention_mask\n\n        return (attention_mask - 1).to(x_dtype).reshape(\n            (attention_mask.shape[0], 1, -1, attention_mask.shape[-1])\n        ) * torch.finfo(x_dtype).max\n\n    def _prepare_self_attention_mask(\n        self, attention_mask: torch.Tensor | None, x_dtype: torch.dtype\n    ) -> torch.Tensor | None:\n        \"\"\"Prepare self-attention mask by converting [0,1] values to additive log-space bias.\n        Input shape: (B, T, T) with values in [0, 1].\n        Output shape: (B, 1, T, T) with 0.0 for full attention and a large negative value\n        for masked positions.\n        Positions with attention_mask <= 0 are fully masked (mapped to the dtype's minimum\n        representable value). 
Strictly positive entries are converted via log-space for\n        smooth attenuation, with small values clamped for numerical stability.\n        Returns None if input is None (no masking).\n        \"\"\"\n        if attention_mask is None:\n            return None\n\n        # Convert [0, 1] attention mask to additive log-space bias:\n        #   1.0 -> log(1.0) = 0.0  (no bias, full attention)\n        #   0.0 -> finfo.min        (fully masked)\n        finfo = torch.finfo(x_dtype)\n        eps = finfo.tiny\n\n        bias = torch.full_like(attention_mask, finfo.min, dtype=x_dtype)\n        positive = attention_mask > 0\n        if positive.any():\n            bias[positive] = torch.log(attention_mask[positive].clamp(min=eps)).to(x_dtype)\n\n        return bias.unsqueeze(1)  # (B, 1, T, T) for head broadcast\n\n    def _prepare_positional_embeddings(\n        self,\n        positions: torch.Tensor,\n        inner_dim: int,\n        max_pos: list[int],\n        use_middle_indices_grid: bool,\n        num_attention_heads: int,\n        x_dtype: torch.dtype,\n    ) -> torch.Tensor:\n        \"\"\"Prepare positional embeddings.\"\"\"\n        freq_grid_generator = generate_freq_grid_np if self.double_precision_rope else generate_freq_grid_pytorch\n        pe = precompute_freqs_cis(\n            positions,\n            dim=inner_dim,\n            out_dtype=x_dtype,\n            theta=self.positional_embedding_theta,\n            max_pos=max_pos,\n            use_middle_indices_grid=use_middle_indices_grid,\n            num_attention_heads=num_attention_heads,\n            rope_type=self.rope_type,\n            freq_grid_generator=freq_grid_generator,\n        )\n        return pe\n\n    def prepare(\n        self,\n        modality: Modality,\n        cross_modality: Modality | None = None,  # noqa: ARG002\n    ) -> TransformerArgs:\n        x = self.patchify_proj(modality.latent)\n        batch_size = x.shape[0]\n        timestep, embedded_timestep = 
self._prepare_timestep(\n            modality.timesteps, self.adaln, batch_size, modality.latent.dtype\n        )\n        prompt_timestep = None\n        if self.prompt_adaln is not None:\n            prompt_timestep, _ = self._prepare_timestep(\n                modality.sigma, self.prompt_adaln, batch_size, modality.latent.dtype\n            )\n        context = self._prepare_context(modality.context, x)\n        attention_mask = self._prepare_attention_mask(modality.context_mask, modality.latent.dtype)\n        pe = self._prepare_positional_embeddings(\n            positions=modality.positions,\n            inner_dim=self.inner_dim,\n            max_pos=self.max_pos,\n            use_middle_indices_grid=self.use_middle_indices_grid,\n            num_attention_heads=self.num_attention_heads,\n            x_dtype=modality.latent.dtype,\n        )\n        self_attention_mask = self._prepare_self_attention_mask(modality.attention_mask, modality.latent.dtype)\n        return TransformerArgs(\n            x=x,\n            context=context,\n            context_mask=attention_mask,\n            timesteps=timestep,\n            embedded_timestep=embedded_timestep,\n            positional_embeddings=pe,\n            cross_positional_embeddings=None,\n            cross_scale_shift_timestep=None,\n            cross_gate_timestep=None,\n            enabled=modality.enabled,\n            prompt_timestep=prompt_timestep,\n            self_attention_mask=self_attention_mask,\n        )\n\n\nclass MultiModalTransformerArgsPreprocessor:\n    def __init__(  # noqa: PLR0913\n        self,\n        patchify_proj: torch.nn.Linear,\n        adaln: AdaLayerNormSingle,\n        cross_scale_shift_adaln: AdaLayerNormSingle,\n        cross_gate_adaln: AdaLayerNormSingle,\n        inner_dim: int,\n        max_pos: list[int],\n        num_attention_heads: int,\n        cross_pe_max_pos: int,\n        use_middle_indices_grid: bool,\n        audio_cross_attention_dim: int,\n        
timestep_scale_multiplier: int,\n        double_precision_rope: bool,\n        positional_embedding_theta: float,\n        rope_type: LTXRopeType,\n        av_ca_timestep_scale_multiplier: int,\n        caption_projection: torch.nn.Module | None = None,\n        prompt_adaln: AdaLayerNormSingle | None = None,\n    ) -> None:\n        self.simple_preprocessor = TransformerArgsPreprocessor(\n            patchify_proj=patchify_proj,\n            adaln=adaln,\n            inner_dim=inner_dim,\n            max_pos=max_pos,\n            num_attention_heads=num_attention_heads,\n            use_middle_indices_grid=use_middle_indices_grid,\n            timestep_scale_multiplier=timestep_scale_multiplier,\n            double_precision_rope=double_precision_rope,\n            positional_embedding_theta=positional_embedding_theta,\n            rope_type=rope_type,\n            caption_projection=caption_projection,\n            prompt_adaln=prompt_adaln,\n        )\n        self.cross_scale_shift_adaln = cross_scale_shift_adaln\n        self.cross_gate_adaln = cross_gate_adaln\n        self.cross_pe_max_pos = cross_pe_max_pos\n        self.audio_cross_attention_dim = audio_cross_attention_dim\n        self.av_ca_timestep_scale_multiplier = av_ca_timestep_scale_multiplier\n\n    def prepare(\n        self,\n        modality: Modality,\n        cross_modality: Modality | None = None,\n    ) -> TransformerArgs:\n        transformer_args = self.simple_preprocessor.prepare(modality)\n        if cross_modality is None:\n            return transformer_args\n\n        if cross_modality.sigma.numel() > 1:\n            if cross_modality.sigma.shape[0] != modality.timesteps.shape[0]:\n                raise ValueError(\"Cross modality sigma must have the same batch size as the modality\")\n            if cross_modality.sigma.ndim != 1:\n                raise ValueError(\"Cross modality sigma must be a 1D tensor\")\n\n        cross_timestep = cross_modality.sigma.view(\n            
modality.timesteps.shape[0], 1, *[1] * len(modality.timesteps.shape[2:])\n        )\n\n        cross_pe = self.simple_preprocessor._prepare_positional_embeddings(\n            positions=modality.positions[:, 0:1, :],\n            inner_dim=self.audio_cross_attention_dim,\n            max_pos=[self.cross_pe_max_pos],\n            use_middle_indices_grid=True,\n            num_attention_heads=self.simple_preprocessor.num_attention_heads,\n            x_dtype=modality.latent.dtype,\n        )\n\n        cross_scale_shift_timestep, cross_gate_timestep = self._prepare_cross_attention_timestep(\n            timestep=cross_timestep,\n            timestep_scale_multiplier=self.simple_preprocessor.timestep_scale_multiplier,\n            batch_size=transformer_args.x.shape[0],\n            hidden_dtype=modality.latent.dtype,\n        )\n\n        return replace(\n            transformer_args,\n            cross_positional_embeddings=cross_pe,\n            cross_scale_shift_timestep=cross_scale_shift_timestep,\n            cross_gate_timestep=cross_gate_timestep,\n        )\n\n    def _prepare_cross_attention_timestep(\n        self,\n        timestep: torch.Tensor | None,\n        timestep_scale_multiplier: int,\n        batch_size: int,\n        hidden_dtype: torch.dtype,\n    ) -> tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"Prepare cross attention timestep embeddings.\"\"\"\n        timestep = timestep * timestep_scale_multiplier\n\n        av_ca_factor = self.av_ca_timestep_scale_multiplier / timestep_scale_multiplier\n\n        scale_shift_timestep, _ = self.cross_scale_shift_adaln(\n            timestep.flatten(),\n            hidden_dtype=hidden_dtype,\n        )\n        scale_shift_timestep = scale_shift_timestep.view(batch_size, -1, scale_shift_timestep.shape[-1])\n        gate_noise_timestep, _ = self.cross_gate_adaln(\n            timestep.flatten() * av_ca_factor,\n            hidden_dtype=hidden_dtype,\n        )\n        gate_noise_timestep = 
gate_noise_timestep.view(batch_size, -1, gate_noise_timestep.shape[-1])\n\n        return scale_shift_timestep, gate_noise_timestep\n\n\n@dataclass\nclass TransformerConfig:\n    dim: int\n    heads: int\n    d_head: int\n    context_dim: int\n    apply_gated_attention: bool = False\n    cross_attention_adaln: bool = False\n\n\nclass BasicAVTransformerBlock(torch.nn.Module):\n    def __init__(\n        self,\n        idx: int,\n        video: TransformerConfig | None = None,\n        audio: TransformerConfig | None = None,\n        rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,\n        norm_eps: float = 1e-6,\n    ):\n        super().__init__()\n\n        self.idx = idx\n        if video is not None:\n            self.attn1 = Attention(\n                query_dim=video.dim,\n                heads=video.heads,\n                dim_head=video.d_head,\n                context_dim=None,\n                rope_type=rope_type,\n                norm_eps=norm_eps,\n                apply_gated_attention=video.apply_gated_attention,\n            )\n            self.attn2 = Attention(\n                query_dim=video.dim,\n                context_dim=video.context_dim,\n                heads=video.heads,\n                dim_head=video.d_head,\n                rope_type=rope_type,\n                norm_eps=norm_eps,\n                apply_gated_attention=video.apply_gated_attention,\n            )\n            self.ff = FeedForward(video.dim, dim_out=video.dim)\n            video_sst_size = adaln_embedding_coefficient(video.cross_attention_adaln)\n            self.scale_shift_table = torch.nn.Parameter(torch.empty(video_sst_size, video.dim))\n\n        if audio is not None:\n            self.audio_attn1 = Attention(\n                query_dim=audio.dim,\n                heads=audio.heads,\n                dim_head=audio.d_head,\n                context_dim=None,\n                rope_type=rope_type,\n                norm_eps=norm_eps,\n                
apply_gated_attention=audio.apply_gated_attention,\n            )\n            self.audio_attn2 = Attention(\n                query_dim=audio.dim,\n                context_dim=audio.context_dim,\n                heads=audio.heads,\n                dim_head=audio.d_head,\n                rope_type=rope_type,\n                norm_eps=norm_eps,\n                apply_gated_attention=audio.apply_gated_attention,\n            )\n            self.audio_ff = FeedForward(audio.dim, dim_out=audio.dim)\n            audio_sst_size = adaln_embedding_coefficient(audio.cross_attention_adaln)\n            self.audio_scale_shift_table = torch.nn.Parameter(torch.empty(audio_sst_size, audio.dim))\n\n        if audio is not None and video is not None:\n            # Q: Video, K,V: Audio\n            self.audio_to_video_attn = Attention(\n                query_dim=video.dim,\n                context_dim=audio.dim,\n                heads=audio.heads,\n                dim_head=audio.d_head,\n                rope_type=rope_type,\n                norm_eps=norm_eps,\n                apply_gated_attention=video.apply_gated_attention,\n            )\n\n            # Q: Audio, K,V: Video\n            self.video_to_audio_attn = Attention(\n                query_dim=audio.dim,\n                context_dim=video.dim,\n                heads=audio.heads,\n                dim_head=audio.d_head,\n                rope_type=rope_type,\n                norm_eps=norm_eps,\n                apply_gated_attention=audio.apply_gated_attention,\n            )\n\n            self.scale_shift_table_a2v_ca_audio = torch.nn.Parameter(torch.empty(5, audio.dim))\n            self.scale_shift_table_a2v_ca_video = torch.nn.Parameter(torch.empty(5, video.dim))\n\n        self.cross_attention_adaln = (video is not None and video.cross_attention_adaln) or (\n            audio is not None and audio.cross_attention_adaln\n        )\n\n        if self.cross_attention_adaln and video is not None:\n            
self.prompt_scale_shift_table = torch.nn.Parameter(torch.empty(2, video.dim))\n        if self.cross_attention_adaln and audio is not None:\n            self.audio_prompt_scale_shift_table = torch.nn.Parameter(torch.empty(2, audio.dim))\n\n        self.norm_eps = norm_eps\n\n    def get_ada_values(\n        self, scale_shift_table: torch.Tensor, batch_size: int, timestep: torch.Tensor, indices: slice\n    ) -> tuple[torch.Tensor, ...]:\n        num_ada_params = scale_shift_table.shape[0]\n\n        ada_values = (\n            scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to(device=timestep.device, dtype=timestep.dtype)\n            + timestep.reshape(batch_size, timestep.shape[1], num_ada_params, -1)[:, :, indices, :]\n        ).unbind(dim=2)\n        return ada_values\n\n    def get_av_ca_ada_values(\n        self,\n        scale_shift_table: torch.Tensor,\n        batch_size: int,\n        scale_shift_timestep: torch.Tensor,\n        gate_timestep: torch.Tensor,\n        scale_shift_indices: slice,\n        num_scale_shift_values: int = 4,\n    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n        scale_shift_ada_values = self.get_ada_values(\n            scale_shift_table[:num_scale_shift_values, :], batch_size, scale_shift_timestep, scale_shift_indices\n        )\n        gate_ada_values = self.get_ada_values(\n            scale_shift_table[num_scale_shift_values:, :], batch_size, gate_timestep, slice(None, None)\n        )\n\n        scale, shift = (t.squeeze(2) for t in scale_shift_ada_values)\n        (gate,) = (t.squeeze(2) for t in gate_ada_values)\n\n        return scale, shift, gate\n\n    def _apply_text_cross_attention(\n        self,\n        x: torch.Tensor,\n        context: torch.Tensor,\n        attn: Attention,\n        scale_shift_table: torch.Tensor,\n        prompt_scale_shift_table: torch.Tensor | None,\n        timestep: torch.Tensor,\n        prompt_timestep: torch.Tensor | None,\n        context_mask: torch.Tensor | None,\n 
       cross_attention_adaln: bool = False,\n    ) -> torch.Tensor:\n        \"\"\"Apply text cross-attention, with optional AdaLN modulation.\"\"\"\n        if cross_attention_adaln:\n            shift_q, scale_q, gate = self.get_ada_values(scale_shift_table, x.shape[0], timestep, slice(6, 9))\n            return apply_cross_attention_adaln(\n                x,\n                context,\n                attn,\n                shift_q,\n                scale_q,\n                gate,\n                prompt_scale_shift_table,\n                prompt_timestep,\n                context_mask,\n                self.norm_eps,\n            )\n        return attn(rms_norm(x, eps=self.norm_eps), context=context, mask=context_mask)\n\n    def forward(  # noqa: PLR0915\n        self,\n        video: TransformerArgs | None,\n        audio: TransformerArgs | None,\n        perturbations: BatchedPerturbationConfig | None = None,\n    ) -> tuple[TransformerArgs | None, TransformerArgs | None]:\n        if video is None and audio is None:\n            raise ValueError(\"At least one of video or audio must be provided\")\n\n        batch_size = (video or audio).x.shape[0]\n\n        if perturbations is None:\n            perturbations = BatchedPerturbationConfig.empty(batch_size)\n\n        vx = video.x if video is not None else None\n        ax = audio.x if audio is not None else None\n\n        run_vx = video is not None and video.enabled and vx.numel() > 0\n        run_ax = audio is not None and audio.enabled and ax.numel() > 0\n\n        run_a2v = run_vx and (audio is not None and ax.numel() > 0)\n        run_v2a = run_ax and (video is not None and vx.numel() > 0)\n\n        if run_vx:\n            vshift_msa, vscale_msa, vgate_msa = self.get_ada_values(\n                self.scale_shift_table, vx.shape[0], video.timesteps, slice(0, 3)\n            )\n            norm_vx = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_msa) + vshift_msa\n            del vshift_msa, 
vscale_msa\n\n            all_perturbed = perturbations.all_in_batch(PerturbationType.SKIP_VIDEO_SELF_ATTN, self.idx)\n            none_perturbed = not perturbations.any_in_batch(PerturbationType.SKIP_VIDEO_SELF_ATTN, self.idx)\n            v_mask = (\n                perturbations.mask_like(PerturbationType.SKIP_VIDEO_SELF_ATTN, self.idx, vx)\n                if not all_perturbed and not none_perturbed\n                else None\n            )\n            vx = (\n                vx\n                + self.attn1(\n                    norm_vx,\n                    pe=video.positional_embeddings,\n                    mask=video.self_attention_mask,\n                    perturbation_mask=v_mask,\n                    all_perturbed=all_perturbed,\n                )\n                * vgate_msa\n            )\n            del vgate_msa, norm_vx, v_mask\n            vx = vx + self._apply_text_cross_attention(\n                vx,\n                video.context,\n                self.attn2,\n                self.scale_shift_table,\n                getattr(self, \"prompt_scale_shift_table\", None),\n                video.timesteps,\n                video.prompt_timestep,\n                video.context_mask,\n                cross_attention_adaln=self.cross_attention_adaln,\n            )\n\n        if run_ax:\n            ashift_msa, ascale_msa, agate_msa = self.get_ada_values(\n                self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(0, 3)\n            )\n\n            norm_ax = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_msa) + ashift_msa\n            del ashift_msa, ascale_msa\n            all_perturbed = perturbations.all_in_batch(PerturbationType.SKIP_AUDIO_SELF_ATTN, self.idx)\n            none_perturbed = not perturbations.any_in_batch(PerturbationType.SKIP_AUDIO_SELF_ATTN, self.idx)\n            a_mask = (\n                perturbations.mask_like(PerturbationType.SKIP_AUDIO_SELF_ATTN, self.idx, ax)\n                if not all_perturbed 
and not none_perturbed\n                else None\n            )\n            ax = (\n                ax\n                + self.audio_attn1(\n                    norm_ax,\n                    pe=audio.positional_embeddings,\n                    mask=audio.self_attention_mask,\n                    perturbation_mask=a_mask,\n                    all_perturbed=all_perturbed,\n                )\n                * agate_msa\n            )\n            del agate_msa, norm_ax, a_mask\n            ax = ax + self._apply_text_cross_attention(\n                ax,\n                audio.context,\n                self.audio_attn2,\n                self.audio_scale_shift_table,\n                getattr(self, \"audio_prompt_scale_shift_table\", None),\n                audio.timesteps,\n                audio.prompt_timestep,\n                audio.context_mask,\n                cross_attention_adaln=self.cross_attention_adaln,\n            )\n\n        # Audio - Video cross attention.\n        if run_a2v or run_v2a:\n            vx_norm3 = rms_norm(vx, eps=self.norm_eps)\n            ax_norm3 = rms_norm(ax, eps=self.norm_eps)\n\n            if run_a2v and not perturbations.all_in_batch(PerturbationType.SKIP_A2V_CROSS_ATTN, self.idx):\n                scale_ca_video_a2v, shift_ca_video_a2v, gate_out_a2v = self.get_av_ca_ada_values(\n                    self.scale_shift_table_a2v_ca_video,\n                    vx.shape[0],\n                    video.cross_scale_shift_timestep,\n                    video.cross_gate_timestep,\n                    slice(0, 2),\n                )\n                vx_scaled = vx_norm3 * (1 + scale_ca_video_a2v) + shift_ca_video_a2v\n                del scale_ca_video_a2v, shift_ca_video_a2v\n\n                scale_ca_audio_a2v, shift_ca_audio_a2v, _ = self.get_av_ca_ada_values(\n                    self.scale_shift_table_a2v_ca_audio,\n                    ax.shape[0],\n                    audio.cross_scale_shift_timestep,\n                    
audio.cross_gate_timestep,\n                    slice(0, 2),\n                )\n                ax_scaled = ax_norm3 * (1 + scale_ca_audio_a2v) + shift_ca_audio_a2v\n                del scale_ca_audio_a2v, shift_ca_audio_a2v\n                a2v_mask = perturbations.mask_like(PerturbationType.SKIP_A2V_CROSS_ATTN, self.idx, vx)\n                vx = vx + (\n                    self.audio_to_video_attn(\n                        vx_scaled,\n                        context=ax_scaled,\n                        pe=video.cross_positional_embeddings,\n                        k_pe=audio.cross_positional_embeddings,\n                    )\n                    * gate_out_a2v\n                    * a2v_mask\n                )\n                del gate_out_a2v, a2v_mask, vx_scaled, ax_scaled\n\n            if run_v2a and not perturbations.all_in_batch(PerturbationType.SKIP_V2A_CROSS_ATTN, self.idx):\n                scale_ca_audio_v2a, shift_ca_audio_v2a, gate_out_v2a = self.get_av_ca_ada_values(\n                    self.scale_shift_table_a2v_ca_audio,\n                    ax.shape[0],\n                    audio.cross_scale_shift_timestep,\n                    audio.cross_gate_timestep,\n                    slice(2, 4),\n                )\n                ax_scaled = ax_norm3 * (1 + scale_ca_audio_v2a) + shift_ca_audio_v2a\n                del scale_ca_audio_v2a, shift_ca_audio_v2a\n                scale_ca_video_v2a, shift_ca_video_v2a, _ = self.get_av_ca_ada_values(\n                    self.scale_shift_table_a2v_ca_video,\n                    vx.shape[0],\n                    video.cross_scale_shift_timestep,\n                    video.cross_gate_timestep,\n                    slice(2, 4),\n                )\n                vx_scaled = vx_norm3 * (1 + scale_ca_video_v2a) + shift_ca_video_v2a\n                del scale_ca_video_v2a, shift_ca_video_v2a\n                v2a_mask = perturbations.mask_like(PerturbationType.SKIP_V2A_CROSS_ATTN, self.idx, ax)\n                ax 
= ax + (\n                    self.video_to_audio_attn(\n                        ax_scaled,\n                        context=vx_scaled,\n                        pe=audio.cross_positional_embeddings,\n                        k_pe=video.cross_positional_embeddings,\n                    )\n                    * gate_out_v2a\n                    * v2a_mask\n                )\n                del gate_out_v2a, v2a_mask, ax_scaled, vx_scaled\n\n            del vx_norm3, ax_norm3\n\n        if run_vx:\n            vshift_mlp, vscale_mlp, vgate_mlp = self.get_ada_values(\n                self.scale_shift_table, vx.shape[0], video.timesteps, slice(3, 6)\n            )\n            vx_scaled = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_mlp) + vshift_mlp\n            vx = vx + self.ff(vx_scaled) * vgate_mlp\n\n            del vshift_mlp, vscale_mlp, vgate_mlp, vx_scaled\n\n        if run_ax:\n            ashift_mlp, ascale_mlp, agate_mlp = self.get_ada_values(\n                self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(3, 6)\n            )\n            ax_scaled = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_mlp) + ashift_mlp\n            ax = ax + self.audio_ff(ax_scaled) * agate_mlp\n\n            del ashift_mlp, ascale_mlp, agate_mlp, ax_scaled\n\n        return replace(video, x=vx) if video is not None else None, replace(audio, x=ax) if audio is not None else None\n\n\ndef apply_cross_attention_adaln(\n    x: torch.Tensor,\n    context: torch.Tensor,\n    attn: Attention,\n    q_shift: torch.Tensor,\n    q_scale: torch.Tensor,\n    q_gate: torch.Tensor,\n    prompt_scale_shift_table: torch.Tensor,\n    prompt_timestep: torch.Tensor,\n    context_mask: torch.Tensor | None = None,\n    norm_eps: float = 1e-6,\n) -> torch.Tensor:\n    batch_size = x.shape[0]\n    shift_kv, scale_kv = (\n        prompt_scale_shift_table[None, None].to(device=x.device, dtype=x.dtype)\n        + prompt_timestep.reshape(batch_size, prompt_timestep.shape[1], 2, -1)\n    
).unbind(dim=2)\n    attn_input = rms_norm(x, eps=norm_eps) * (1 + q_scale) + q_shift\n    encoder_hidden_states = context * (1 + scale_kv) + shift_kv\n    return attn(attn_input, context=encoder_hidden_states, mask=context_mask) * q_gate\n\n\nclass GELUApprox(torch.nn.Module):\n    def __init__(self, dim_in: int, dim_out: int) -> None:\n        super().__init__()\n        self.proj = torch.nn.Linear(dim_in, dim_out)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        return torch.nn.functional.gelu(self.proj(x), approximate=\"tanh\")\n\n\nclass FeedForward(torch.nn.Module):\n    def __init__(self, dim: int, dim_out: int, mult: int = 4) -> None:\n        super().__init__()\n        inner_dim = int(dim * mult)\n        project_in = GELUApprox(dim, inner_dim)\n\n        self.net = torch.nn.Sequential(project_in, torch.nn.Identity(), torch.nn.Linear(inner_dim, dim_out))\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        return self.net(x)\n\n\nclass LTXModelType(Enum):\n    AudioVideo = \"ltx av model\"\n    VideoOnly = \"ltx video only model\"\n    AudioOnly = \"ltx audio only model\"\n\n    def is_video_enabled(self) -> bool:\n        return self in (LTXModelType.AudioVideo, LTXModelType.VideoOnly)\n\n    def is_audio_enabled(self) -> bool:\n        return self in (LTXModelType.AudioVideo, LTXModelType.AudioOnly)\n\n\nclass LTXModel(torch.nn.Module):\n    \"\"\"\n    LTX model transformer implementation.\n    This class implements the transformer blocks for the LTX model.\n    \"\"\"\n\n    def __init__(  # noqa: PLR0913\n        self,\n        *,\n        model_type: LTXModelType = LTXModelType.AudioVideo,\n        num_attention_heads: int = 32,\n        attention_head_dim: int = 128,\n        in_channels: int = 128,\n        out_channels: int = 128,\n        num_layers: int = 48,\n        cross_attention_dim: int = 4096,\n        norm_eps: float = 1e-06,\n        caption_channels: int = 3840,\n        positional_embedding_theta: 
float = 10000.0,\n        positional_embedding_max_pos: list[int] | None = [20, 2048, 2048],\n        timestep_scale_multiplier: int = 1000,\n        use_middle_indices_grid: bool = True,\n        audio_num_attention_heads: int = 32,\n        audio_attention_head_dim: int = 64,\n        audio_in_channels: int = 128,\n        audio_out_channels: int = 128,\n        audio_cross_attention_dim: int = 2048,\n        audio_positional_embedding_max_pos: list[int] | None = [20],\n        av_ca_timestep_scale_multiplier: int = 1000,\n        rope_type: LTXRopeType = LTXRopeType.SPLIT,\n        double_precision_rope: bool = True,\n        apply_gated_attention: bool = False,\n        cross_attention_adaln: bool = False,\n    ):\n        super().__init__()\n        self._enable_gradient_checkpointing = False\n        self.use_middle_indices_grid = use_middle_indices_grid\n        self.rope_type = rope_type\n        self.double_precision_rope = double_precision_rope\n        self.timestep_scale_multiplier = timestep_scale_multiplier\n        self.positional_embedding_theta = positional_embedding_theta\n        self.model_type = model_type\n        self.cross_attention_adaln = cross_attention_adaln\n        cross_pe_max_pos = None\n        if model_type.is_video_enabled():\n            if positional_embedding_max_pos is None:\n                positional_embedding_max_pos = [20, 2048, 2048]\n            self.positional_embedding_max_pos = positional_embedding_max_pos\n            self.num_attention_heads = num_attention_heads\n            self.inner_dim = num_attention_heads * attention_head_dim\n            self._init_video(\n                in_channels=in_channels,\n                out_channels=out_channels,\n                caption_channels=caption_channels,\n                norm_eps=norm_eps,\n            )\n\n        if model_type.is_audio_enabled():\n            if audio_positional_embedding_max_pos is None:\n                audio_positional_embedding_max_pos = [20]\n      
      self.audio_positional_embedding_max_pos = audio_positional_embedding_max_pos\n            self.audio_num_attention_heads = audio_num_attention_heads\n            self.audio_inner_dim = self.audio_num_attention_heads * audio_attention_head_dim\n            self._init_audio(\n                in_channels=audio_in_channels,\n                out_channels=audio_out_channels,\n                caption_channels=caption_channels,\n                norm_eps=norm_eps,\n            )\n\n        if model_type.is_video_enabled() and model_type.is_audio_enabled():\n            cross_pe_max_pos = max(self.positional_embedding_max_pos[0], self.audio_positional_embedding_max_pos[0])\n            self.av_ca_timestep_scale_multiplier = av_ca_timestep_scale_multiplier\n            self.audio_cross_attention_dim = audio_cross_attention_dim\n            self._init_audio_video(num_scale_shift_values=4)\n\n        self._init_preprocessors(cross_pe_max_pos)\n        # Initialize transformer blocks\n        self._init_transformer_blocks(\n            num_layers=num_layers,\n            attention_head_dim=attention_head_dim if model_type.is_video_enabled() else 0,\n            cross_attention_dim=cross_attention_dim,\n            audio_attention_head_dim=audio_attention_head_dim if model_type.is_audio_enabled() else 0,\n            audio_cross_attention_dim=audio_cross_attention_dim,\n            norm_eps=norm_eps,\n            apply_gated_attention=apply_gated_attention,\n        )\n\n    @property\n    def _adaln_embedding_coefficient(self) -> int:\n        return adaln_embedding_coefficient(self.cross_attention_adaln)\n\n    def _init_video(\n        self,\n        in_channels: int,\n        out_channels: int,\n        caption_channels: int,\n        norm_eps: float,\n    ) -> None:\n        \"\"\"Initialize video-specific components.\"\"\"\n        # Video input components\n        self.patchify_proj = torch.nn.Linear(in_channels, self.inner_dim, bias=True)\n        self.adaln_single 
= AdaLayerNormSingle(self.inner_dim, embedding_coefficient=self._adaln_embedding_coefficient)\n        self.prompt_adaln_single = AdaLayerNormSingle(self.inner_dim, embedding_coefficient=2) if self.cross_attention_adaln else None\n\n        # Video caption projection\n        if caption_channels is not None:\n            self.caption_projection = PixArtAlphaTextProjection(\n                in_features=caption_channels,\n                hidden_size=self.inner_dim,\n            )\n\n        # Video output components\n        self.scale_shift_table = torch.nn.Parameter(torch.empty(2, self.inner_dim))\n        self.norm_out = torch.nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=norm_eps)\n        self.proj_out = torch.nn.Linear(self.inner_dim, out_channels)\n\n    def _init_audio(\n        self,\n        in_channels: int,\n        out_channels: int,\n        caption_channels: int,\n        norm_eps: float,\n    ) -> None:\n        \"\"\"Initialize audio-specific components.\"\"\"\n\n        # Audio input components\n        self.audio_patchify_proj = torch.nn.Linear(in_channels, self.audio_inner_dim, bias=True)\n\n        self.audio_adaln_single = AdaLayerNormSingle(self.audio_inner_dim, embedding_coefficient=self._adaln_embedding_coefficient)\n        self.audio_prompt_adaln_single = AdaLayerNormSingle(self.audio_inner_dim, embedding_coefficient=2) if self.cross_attention_adaln else None\n\n        # Audio caption projection\n        if caption_channels is not None:\n            self.audio_caption_projection = PixArtAlphaTextProjection(\n                in_features=caption_channels,\n                hidden_size=self.audio_inner_dim,\n            )\n\n        # Audio output components\n        self.audio_scale_shift_table = torch.nn.Parameter(torch.empty(2, self.audio_inner_dim))\n        self.audio_norm_out = torch.nn.LayerNorm(self.audio_inner_dim, elementwise_affine=False, eps=norm_eps)\n        self.audio_proj_out = torch.nn.Linear(self.audio_inner_dim, 
out_channels)\n\n    def _init_audio_video(\n        self,\n        num_scale_shift_values: int,\n    ) -> None:\n        \"\"\"Initialize audio-video cross-attention components.\"\"\"\n        self.av_ca_video_scale_shift_adaln_single = AdaLayerNormSingle(\n            self.inner_dim,\n            embedding_coefficient=num_scale_shift_values,\n        )\n\n        self.av_ca_audio_scale_shift_adaln_single = AdaLayerNormSingle(\n            self.audio_inner_dim,\n            embedding_coefficient=num_scale_shift_values,\n        )\n\n        self.av_ca_a2v_gate_adaln_single = AdaLayerNormSingle(\n            self.inner_dim,\n            embedding_coefficient=1,\n        )\n\n        self.av_ca_v2a_gate_adaln_single = AdaLayerNormSingle(\n            self.audio_inner_dim,\n            embedding_coefficient=1,\n        )\n\n    def _init_preprocessors(\n        self,\n        cross_pe_max_pos: int | None = None,\n    ) -> None:\n        \"\"\"Initialize preprocessors for LTX.\"\"\"\n\n        if self.model_type.is_video_enabled() and self.model_type.is_audio_enabled():\n            self.video_args_preprocessor = MultiModalTransformerArgsPreprocessor(\n                patchify_proj=self.patchify_proj,\n                adaln=self.adaln_single,\n                cross_scale_shift_adaln=self.av_ca_video_scale_shift_adaln_single,\n                cross_gate_adaln=self.av_ca_a2v_gate_adaln_single,\n                inner_dim=self.inner_dim,\n                max_pos=self.positional_embedding_max_pos,\n                num_attention_heads=self.num_attention_heads,\n                cross_pe_max_pos=cross_pe_max_pos,\n                use_middle_indices_grid=self.use_middle_indices_grid,\n                audio_cross_attention_dim=self.audio_cross_attention_dim,\n                timestep_scale_multiplier=self.timestep_scale_multiplier,\n                double_precision_rope=self.double_precision_rope,\n                positional_embedding_theta=self.positional_embedding_theta,\n    
            rope_type=self.rope_type,\n                av_ca_timestep_scale_multiplier=self.av_ca_timestep_scale_multiplier,\n                caption_projection=getattr(self, \"caption_projection\", None),\n                prompt_adaln=getattr(self, \"prompt_adaln_single\", None),\n            )\n            self.audio_args_preprocessor = MultiModalTransformerArgsPreprocessor(\n                patchify_proj=self.audio_patchify_proj,\n                adaln=self.audio_adaln_single,\n                cross_scale_shift_adaln=self.av_ca_audio_scale_shift_adaln_single,\n                cross_gate_adaln=self.av_ca_v2a_gate_adaln_single,\n                inner_dim=self.audio_inner_dim,\n                max_pos=self.audio_positional_embedding_max_pos,\n                num_attention_heads=self.audio_num_attention_heads,\n                cross_pe_max_pos=cross_pe_max_pos,\n                use_middle_indices_grid=self.use_middle_indices_grid,\n                audio_cross_attention_dim=self.audio_cross_attention_dim,\n                timestep_scale_multiplier=self.timestep_scale_multiplier,\n                double_precision_rope=self.double_precision_rope,\n                positional_embedding_theta=self.positional_embedding_theta,\n                rope_type=self.rope_type,\n                av_ca_timestep_scale_multiplier=self.av_ca_timestep_scale_multiplier,\n                caption_projection=getattr(self, \"audio_caption_projection\", None),\n                prompt_adaln=getattr(self, \"audio_prompt_adaln_single\", None),\n            )\n        elif self.model_type.is_video_enabled():\n            self.video_args_preprocessor = TransformerArgsPreprocessor(\n                patchify_proj=self.patchify_proj,\n                adaln=self.adaln_single,\n                inner_dim=self.inner_dim,\n                max_pos=self.positional_embedding_max_pos,\n                num_attention_heads=self.num_attention_heads,\n                
use_middle_indices_grid=self.use_middle_indices_grid,\n                timestep_scale_multiplier=self.timestep_scale_multiplier,\n                double_precision_rope=self.double_precision_rope,\n                positional_embedding_theta=self.positional_embedding_theta,\n                rope_type=self.rope_type,\n                caption_projection=getattr(self, \"caption_projection\", None),\n                prompt_adaln=getattr(self, \"prompt_adaln_single\", None),\n            )\n        elif self.model_type.is_audio_enabled():\n            self.audio_args_preprocessor = TransformerArgsPreprocessor(\n                patchify_proj=self.audio_patchify_proj,\n                adaln=self.audio_adaln_single,\n                inner_dim=self.audio_inner_dim,\n                max_pos=self.audio_positional_embedding_max_pos,\n                num_attention_heads=self.audio_num_attention_heads,\n                use_middle_indices_grid=self.use_middle_indices_grid,\n                timestep_scale_multiplier=self.timestep_scale_multiplier,\n                double_precision_rope=self.double_precision_rope,\n                positional_embedding_theta=self.positional_embedding_theta,\n                rope_type=self.rope_type,\n                caption_projection=getattr(self, \"audio_caption_projection\", None),\n                prompt_adaln=getattr(self, \"audio_prompt_adaln_single\", None),\n            )\n\n    def _init_transformer_blocks(\n        self,\n        num_layers: int,\n        attention_head_dim: int,\n        cross_attention_dim: int,\n        audio_attention_head_dim: int,\n        audio_cross_attention_dim: int,\n        norm_eps: float,\n        apply_gated_attention: bool,\n    ) -> None:\n        \"\"\"Initialize transformer blocks for LTX.\"\"\"\n        video_config = (\n            TransformerConfig(\n                dim=self.inner_dim,\n                heads=self.num_attention_heads,\n                d_head=attention_head_dim,\n                
context_dim=cross_attention_dim,\n                apply_gated_attention=apply_gated_attention,\n                cross_attention_adaln=self.cross_attention_adaln,\n            )\n            if self.model_type.is_video_enabled()\n            else None\n        )\n        audio_config = (\n            TransformerConfig(\n                dim=self.audio_inner_dim,\n                heads=self.audio_num_attention_heads,\n                d_head=audio_attention_head_dim,\n                context_dim=audio_cross_attention_dim,\n                apply_gated_attention=apply_gated_attention,\n                cross_attention_adaln=self.cross_attention_adaln,\n            )\n            if self.model_type.is_audio_enabled()\n            else None\n        )\n        self.transformer_blocks = torch.nn.ModuleList(\n            [\n                BasicAVTransformerBlock(\n                    idx=idx,\n                    video=video_config,\n                    audio=audio_config,\n                    rope_type=self.rope_type,\n                    norm_eps=norm_eps,\n                )\n                for idx in range(num_layers)\n            ]\n        )\n\n    def set_gradient_checkpointing(self, enable: bool) -> None:\n        \"\"\"Enable or disable gradient checkpointing for transformer blocks.\n        Gradient checkpointing trades compute for memory by recomputing activations\n        during the backward pass instead of storing them. 
This can significantly\n        reduce memory usage at the cost of ~20-30% slower training.\n        Args:\n            enable: Whether to enable gradient checkpointing\n        \"\"\"\n        self._enable_gradient_checkpointing = enable\n\n    def _process_transformer_blocks(\n        self,\n        video: TransformerArgs | None,\n        audio: TransformerArgs | None,\n        perturbations: BatchedPerturbationConfig,\n        use_gradient_checkpointing: bool = False,\n        use_gradient_checkpointing_offload: bool = False,\n    ) -> tuple[TransformerArgs, TransformerArgs]:\n        \"\"\"Process transformer blocks for LTXAV.\"\"\"\n\n        # Process transformer blocks\n        for block in self.transformer_blocks:\n            video, audio = gradient_checkpoint_forward(\n                block,\n                use_gradient_checkpointing,\n                use_gradient_checkpointing_offload,\n                video=video,\n                audio=audio,\n                perturbations=perturbations,\n            )\n\n        return video, audio\n\n    def _process_output(\n        self,\n        scale_shift_table: torch.Tensor,\n        norm_out: torch.nn.LayerNorm,\n        proj_out: torch.nn.Linear,\n        x: torch.Tensor,\n        embedded_timestep: torch.Tensor,\n    ) -> torch.Tensor:\n        \"\"\"Process output for LTXV.\"\"\"\n        # Apply scale-shift modulation\n        scale_shift_values = (\n            scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + embedded_timestep[:, :, None]\n        )\n        shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]\n\n        x = norm_out(x)\n        x = x * (1 + scale) + shift\n        x = proj_out(x)\n        return x\n\n    def _forward(\n        self,\n        video: Modality | None,\n        audio: Modality | None,\n        perturbations: BatchedPerturbationConfig,\n        use_gradient_checkpointing: bool = False,\n        use_gradient_checkpointing_offload: bool = 
False,\n    ) -> tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"\n        Forward pass for LTX models.\n        Returns:\n            Processed output tensors\n        \"\"\"\n        if not self.model_type.is_video_enabled() and video is not None:\n            raise ValueError(\"Video is not enabled for this model\")\n        if not self.model_type.is_audio_enabled() and audio is not None:\n            raise ValueError(\"Audio is not enabled for this model\")\n\n        video_args = self.video_args_preprocessor.prepare(video, audio) if video is not None else None\n        audio_args = self.audio_args_preprocessor.prepare(audio, video) if audio is not None else None\n        # Process transformer blocks\n        video_out, audio_out = self._process_transformer_blocks(\n            video=video_args,\n            audio=audio_args,\n            perturbations=perturbations,\n            use_gradient_checkpointing=use_gradient_checkpointing,\n            use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,\n        )\n\n        # Process output\n        vx = (\n            self._process_output(\n                self.scale_shift_table, self.norm_out, self.proj_out, video_out.x, video_out.embedded_timestep\n            )\n            if video_out is not None\n            else None\n        )\n        ax = (\n            self._process_output(\n                self.audio_scale_shift_table,\n                self.audio_norm_out,\n                self.audio_proj_out,\n                audio_out.x,\n                audio_out.embedded_timestep,\n            )\n            if audio_out is not None\n            else None\n        )\n        return vx, ax\n\n    def forward(self, video_latents, video_positions, video_context, video_timesteps, audio_latents, audio_positions, audio_context, audio_timesteps, sigma, use_gradient_checkpointing=False, use_gradient_checkpointing_offload=False):\n        cross_pe_max_pos = None\n        if 
self.model_type.is_video_enabled() and self.model_type.is_audio_enabled():\n            cross_pe_max_pos = max(self.positional_embedding_max_pos[0], self.audio_positional_embedding_max_pos[0])\n        self._init_preprocessors(cross_pe_max_pos)\n        video = Modality(video_latents, sigma, video_timesteps, video_positions, video_context)\n        audio = Modality(audio_latents, sigma, audio_timesteps, audio_positions, audio_context) if audio_latents is not None else None\n        vx, ax = self._forward(video=video, audio=audio, perturbations=None, use_gradient_checkpointing=use_gradient_checkpointing, use_gradient_checkpointing_offload=use_gradient_checkpointing_offload)\n        return vx, ax\n"
  },
  {
    "path": "diffsynth/models/ltx2_text_encoder.py",
    "content": "import math\nimport torch\nimport torch.nn as nn\nfrom einops import rearrange\nfrom transformers import Gemma3ForConditionalGeneration, Gemma3Config, AutoTokenizer\nfrom .ltx2_dit import (LTXRopeType, generate_freq_grid_np, generate_freq_grid_pytorch, precompute_freqs_cis, Attention,\n                       FeedForward)\nfrom .ltx2_common import rms_norm\n\n\nclass LTX2TextEncoder(Gemma3ForConditionalGeneration):\n    def __init__(self):\n        config = Gemma3Config(\n            **{\n                \"architectures\": [\"Gemma3ForConditionalGeneration\"],\n                \"boi_token_index\": 255999,\n                \"dtype\": \"bfloat16\",\n                \"eoi_token_index\": 256000,\n                \"eos_token_id\": [1, 106],\n                \"image_token_index\": 262144,\n                \"initializer_range\": 0.02,\n                \"mm_tokens_per_image\": 256,\n                \"model_type\": \"gemma3\",\n                \"text_config\": {\n                    \"_sliding_window_pattern\": 6,\n                    \"attention_bias\": False,\n                    \"attention_dropout\": 0.0,\n                    \"attn_logit_softcapping\": None,\n                    \"cache_implementation\": \"hybrid\",\n                    \"dtype\": \"bfloat16\",\n                    \"final_logit_softcapping\": None,\n                    \"head_dim\": 256,\n                    \"hidden_activation\": \"gelu_pytorch_tanh\",\n                    \"hidden_size\": 3840,\n                    \"initializer_range\": 0.02,\n                    \"intermediate_size\": 15360,\n                    \"layer_types\": [\n                        \"sliding_attention\", \"sliding_attention\", \"sliding_attention\", \"sliding_attention\",\n                        \"sliding_attention\", \"full_attention\", \"sliding_attention\", \"sliding_attention\",\n                        \"sliding_attention\", \"sliding_attention\", \"sliding_attention\", \"full_attention\",\n             
           \"sliding_attention\", \"sliding_attention\", \"sliding_attention\", \"sliding_attention\",\n                        \"sliding_attention\", \"full_attention\", \"sliding_attention\", \"sliding_attention\",\n                        \"sliding_attention\", \"sliding_attention\", \"sliding_attention\", \"full_attention\",\n                        \"sliding_attention\", \"sliding_attention\", \"sliding_attention\", \"sliding_attention\",\n                        \"sliding_attention\", \"full_attention\", \"sliding_attention\", \"sliding_attention\",\n                        \"sliding_attention\", \"sliding_attention\", \"sliding_attention\", \"full_attention\",\n                        \"sliding_attention\", \"sliding_attention\", \"sliding_attention\", \"sliding_attention\",\n                        \"sliding_attention\", \"full_attention\", \"sliding_attention\", \"sliding_attention\",\n                        \"sliding_attention\", \"sliding_attention\", \"sliding_attention\", \"full_attention\"\n                    ],\n                    \"max_position_embeddings\": 131072,\n                    \"model_type\": \"gemma3_text\",\n                    \"num_attention_heads\": 16,\n                    \"num_hidden_layers\": 48,\n                    \"num_key_value_heads\": 8,\n                    \"query_pre_attn_scalar\": 256,\n                    \"rms_norm_eps\": 1e-06,\n                    \"rope_local_base_freq\": 10000,\n                    \"rope_scaling\": {\n                        \"factor\": 8.0,\n                        \"rope_type\": \"linear\"\n                    },\n                    \"rope_theta\": 1000000,\n                    \"sliding_window\": 1024,\n                    \"sliding_window_pattern\": 6,\n                    \"use_bidirectional_attention\": False,\n                    \"use_cache\": True,\n                    \"vocab_size\": 262208\n                },\n                \"transformers_version\": \"4.57.3\",\n                
\"vision_config\": {\n                    \"attention_dropout\": 0.0,\n                    \"dtype\": \"bfloat16\",\n                    \"hidden_act\": \"gelu_pytorch_tanh\",\n                    \"hidden_size\": 1152,\n                    \"image_size\": 896,\n                    \"intermediate_size\": 4304,\n                    \"layer_norm_eps\": 1e-06,\n                    \"model_type\": \"siglip_vision_model\",\n                    \"num_attention_heads\": 16,\n                    \"num_channels\": 3,\n                    \"num_hidden_layers\": 27,\n                    \"patch_size\": 14,\n                    \"vision_use_head\": False\n                }\n            })\n        super().__init__(config)\n\n\nclass LTXVGemmaTokenizer:\n    \"\"\"\n    Tokenizer wrapper for Gemma models compatible with LTXV processes.\n    This class wraps HuggingFace's `AutoTokenizer` for use with Gemma text encoders,\n    ensuring correct settings and output formatting for downstream consumption.\n    \"\"\"\n\n    def __init__(self, tokenizer_path: str, max_length: int = 1024):\n        \"\"\"\n        Initialize the tokenizer.\n        Args:\n            tokenizer_path (str): Path to the pretrained tokenizer files or model directory.\n            max_length (int, optional): Max sequence length for encoding. 
Defaults to 1024.\n        \"\"\"\n        self.tokenizer = AutoTokenizer.from_pretrained(\n            tokenizer_path, local_files_only=True, model_max_length=max_length\n        )\n        # Gemma expects left padding for chat-style prompts; for plain text it doesn't matter much.\n        self.tokenizer.padding_side = \"left\"\n        if self.tokenizer.pad_token is None:\n            self.tokenizer.pad_token = self.tokenizer.eos_token\n\n        self.max_length = max_length\n\n    def tokenize_with_weights(self, text: str, return_word_ids: bool = False) -> dict[str, list[tuple[int, int]]]:\n        \"\"\"\n        Tokenize the given text and return token IDs and attention weights.\n        Args:\n            text (str): The input string to tokenize.\n            return_word_ids (bool, optional): If True, includes the token's position (index) in the output tuples.\n                                              If False (default), omits the indices.\n        Returns:\n            dict[str, list[tuple[int, int]]] OR dict[str, list[tuple[int, int, int]]]:\n                A dictionary with a \"gemma\" key mapping to:\n                    - a list of (token_id, attention_mask) tuples if return_word_ids is False;\n                    - a list of (token_id, attention_mask, index) tuples if return_word_ids is True.\n        Example:\n            >>> tokenizer = LTXVGemmaTokenizer(\"path/to/tokenizer\", max_length=8)\n            >>> tokenizer.tokenize_with_weights(\"hello world\")\n            {'gemma': [(1234, 1), (5678, 1), (2, 0), ...]}\n        \"\"\"\n        text = text.strip()\n        encoded = self.tokenizer(\n            text,\n            padding=\"max_length\",\n            max_length=self.max_length,\n            truncation=True,\n            return_tensors=\"pt\",\n        )\n        input_ids = encoded.input_ids\n        attention_mask = encoded.attention_mask\n        tuples = [\n            (token_id, attn, i) for i, (token_id, attn) in 
enumerate(zip(input_ids[0], attention_mask[0], strict=True))\n        ]\n        out = {\"gemma\": tuples}\n\n        if not return_word_ids:\n            # Return only (token_id, attention_mask) pairs, omitting token position\n            out = {k: [(t, w) for t, w, _ in v] for k, v in out.items()}\n\n        return out\n\n\nclass GemmaFeaturesExtractorProjLinear(nn.Module):\n    \"\"\"\n    Feature extractor module for Gemma models.\n    This module applies a single linear projection to the input tensor.\n    It expects a flattened feature tensor of shape (batch_size, 3840*49).\n    The linear layer maps this to a (batch_size, 3840) embedding.\n    Attributes:\n        aggregate_embed (nn.Linear): Linear projection layer.\n    \"\"\"\n\n    def __init__(self) -> None:\n        \"\"\"\n        Initialize the GemmaFeaturesExtractorProjLinear module.\n        The input dimension is expected to be 3840 * 49, and the output is 3840.\n        \"\"\"\n        super().__init__()\n        self.aggregate_embed = nn.Linear(3840 * 49, 3840, bias=False)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: torch.Tensor,\n        padding_side: str = \"left\",\n    ) -> tuple[torch.Tensor, torch.Tensor | None]:\n        encoded = torch.stack(hidden_states, dim=-1) if isinstance(hidden_states, (list, tuple)) else hidden_states\n        dtype = encoded.dtype\n        sequence_lengths = attention_mask.sum(dim=-1)\n        normed = _norm_and_concat_padded_batch(encoded, sequence_lengths, padding_side)\n        features = self.aggregate_embed(normed.to(dtype))\n        return features, features\n\n\nclass GemmaSeperatedFeaturesExtractorProjLinear(nn.Module):\n    \"\"\"22B: per-token RMS norm → rescale → dual aggregate embeds\"\"\"\n\n    def __init__(\n        self,\n        num_layers: int,\n        embedding_dim: int,\n        video_inner_dim: int,\n        audio_inner_dim: int,\n    ):\n        super().__init__()\n        in_dim = 
embedding_dim * num_layers\n        self.video_aggregate_embed = torch.nn.Linear(in_dim, video_inner_dim, bias=True)\n        self.audio_aggregate_embed = torch.nn.Linear(in_dim, audio_inner_dim, bias=True)\n        self.embedding_dim = embedding_dim\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: torch.Tensor,\n        padding_side: str = \"left\",  # noqa: ARG002\n    ) -> tuple[torch.Tensor, torch.Tensor | None]:\n        encoded = torch.stack(hidden_states, dim=-1) if isinstance(hidden_states, (list, tuple)) else hidden_states\n        normed = norm_and_concat_per_token_rms(encoded, attention_mask)\n        normed = normed.to(encoded.dtype)\n        v_dim = self.video_aggregate_embed.out_features\n        video = self.video_aggregate_embed(_rescale_norm(normed, v_dim, self.embedding_dim))\n        audio = None\n        if self.audio_aggregate_embed is not None:\n            a_dim = self.audio_aggregate_embed.out_features\n            audio = self.audio_aggregate_embed(_rescale_norm(normed, a_dim, self.embedding_dim))\n        return video, audio\n\n\n\nclass _BasicTransformerBlock1D(nn.Module):\n    def __init__(\n        self,\n        dim: int,\n        heads: int,\n        dim_head: int,\n        rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,\n        apply_gated_attention: bool = False,\n    ):\n        super().__init__()\n\n        self.attn1 = Attention(\n            query_dim=dim,\n            heads=heads,\n            dim_head=dim_head,\n            rope_type=rope_type,\n            apply_gated_attention=apply_gated_attention,\n        )\n\n        self.ff = FeedForward(\n            dim,\n            dim_out=dim,\n        )\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: torch.Tensor | None = None,\n        pe: torch.Tensor | None = None,\n    ) -> torch.Tensor:\n        # Notice that normalization is always applied before the real computation in the 
following blocks.\n\n        # 1. Normalization Before Self-Attention\n        norm_hidden_states = rms_norm(hidden_states)\n\n        norm_hidden_states = norm_hidden_states.squeeze(1)\n\n        # 2. Self-Attention\n        attn_output = self.attn1(norm_hidden_states, mask=attention_mask, pe=pe)\n\n        hidden_states = attn_output + hidden_states\n        if hidden_states.ndim == 4:\n            hidden_states = hidden_states.squeeze(1)\n\n        # 3. Normalization before Feed-Forward\n        norm_hidden_states = rms_norm(hidden_states)\n\n        # 4. Feed-forward\n        ff_output = self.ff(norm_hidden_states)\n\n        hidden_states = ff_output + hidden_states\n        if hidden_states.ndim == 4:\n            hidden_states = hidden_states.squeeze(1)\n\n        return hidden_states\n\n\nclass Embeddings1DConnector(nn.Module):\n    \"\"\"\n    Embeddings1DConnector applies a 1D transformer-based processing to sequential embeddings (e.g., for video, audio, or\n    other modalities). It supports rotary positional encoding (rope), optional causal temporal positioning, and can\n    substitute padded positions with learnable registers. The module is highly configurable for head size, number of\n    layers, and register usage.\n    Args:\n        attention_head_dim (int): Dimension of each attention head (default=128).\n        num_attention_heads (int): Number of attention heads (default=30).\n        num_layers (int): Number of transformer layers (default=2).\n        positional_embedding_theta (float): Scaling factor for position embedding (default=10000.0).\n        positional_embedding_max_pos (list[int] | None): Max positions for positional embeddings (default=[1]).\n        causal_temporal_positioning (bool): If True, uses causal attention (default=False).\n        num_learnable_registers (int | None): Number of learnable registers to replace padded tokens. If None, disables\n            register replacement. 
(default=128)\n        rope_type (LTXRopeType): The RoPE variant to use (default=LTXRopeType.SPLIT).\n        double_precision_rope (bool): Use double precision rope calculation (default=True).\n        apply_gated_attention (bool): Passed through to every transformer block's attention layer (default=False).\n    \"\"\"\n\n    _supports_gradient_checkpointing = True\n\n    def __init__(\n        self,\n        attention_head_dim: int = 128,\n        num_attention_heads: int = 30,\n        num_layers: int = 2,\n        positional_embedding_theta: float = 10000.0,\n        positional_embedding_max_pos: list[int] | None = [4096],\n        causal_temporal_positioning: bool = False,\n        num_learnable_registers: int | None = 128,\n        rope_type: LTXRopeType = LTXRopeType.SPLIT,\n        double_precision_rope: bool = True,\n        apply_gated_attention: bool = False,\n    ):\n        super().__init__()\n        self.num_attention_heads = num_attention_heads\n        self.inner_dim = num_attention_heads * attention_head_dim\n        self.causal_temporal_positioning = causal_temporal_positioning\n        self.positional_embedding_theta = positional_embedding_theta\n        self.positional_embedding_max_pos = (\n            positional_embedding_max_pos if positional_embedding_max_pos is not None else [1]\n        )\n        self.rope_type = rope_type\n        self.double_precision_rope = double_precision_rope\n        self.transformer_1d_blocks = nn.ModuleList(\n            [\n                _BasicTransformerBlock1D(\n                    dim=self.inner_dim,\n                    heads=num_attention_heads,\n                    dim_head=attention_head_dim,\n                    rope_type=rope_type,\n                    apply_gated_attention=apply_gated_attention,\n                )\n                for _ in range(num_layers)\n            ]\n        )\n\n        self.num_learnable_registers = num_learnable_registers\n        if self.num_learnable_registers:\n            self.learnable_registers = nn.Parameter(\n                torch.rand(self.num_learnable_registers, 
self.inner_dim, dtype=torch.bfloat16) * 2.0 - 1.0\n            )\n\n    def _replace_padded_with_learnable_registers(\n        self, hidden_states: torch.Tensor, attention_mask: torch.Tensor\n    ) -> tuple[torch.Tensor, torch.Tensor]:\n        assert hidden_states.shape[1] % self.num_learnable_registers == 0, (\n            f\"Hidden states sequence length {hidden_states.shape[1]} must be divisible by num_learnable_registers \"\n            f\"{self.num_learnable_registers}.\"\n        )\n\n        num_registers_duplications = hidden_states.shape[1] // self.num_learnable_registers\n        learnable_registers = torch.tile(self.learnable_registers, (num_registers_duplications, 1))\n        attention_mask_binary = (attention_mask.squeeze(1).squeeze(1).unsqueeze(-1) >= -9000.0).int()\n\n        non_zero_hidden_states = hidden_states[:, attention_mask_binary.squeeze().bool(), :]\n        non_zero_nums = non_zero_hidden_states.shape[1]\n        pad_length = hidden_states.shape[1] - non_zero_nums\n        adjusted_hidden_states = nn.functional.pad(non_zero_hidden_states, pad=(0, 0, 0, pad_length), value=0)\n        flipped_mask = torch.flip(attention_mask_binary, dims=[1])\n        hidden_states = flipped_mask * adjusted_hidden_states + (1 - flipped_mask) * learnable_registers\n\n        attention_mask = torch.full_like(\n            attention_mask,\n            0.0,\n            dtype=attention_mask.dtype,\n            device=attention_mask.device,\n        )\n\n        return hidden_states, attention_mask\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: torch.Tensor | None = None,\n    ) -> tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"\n        Forward pass of Embeddings1DConnector.\n        Args:\n            hidden_states (torch.Tensor): Input tensor of embeddings (shape [batch, seq_len, feature_dim]).\n            attention_mask (torch.Tensor|None): Optional mask for valid tokens (shape compatible with 
hidden_states).\n        Returns:\n            tuple[torch.Tensor, torch.Tensor]: Processed features and the corresponding (possibly modified) mask.\n        \"\"\"\n        if self.num_learnable_registers:\n            hidden_states, attention_mask = self._replace_padded_with_learnable_registers(hidden_states, attention_mask)\n\n        indices_grid = torch.arange(hidden_states.shape[1], dtype=torch.float32, device=hidden_states.device)\n        indices_grid = indices_grid[None, None, :]\n        freq_grid_generator = generate_freq_grid_np if self.double_precision_rope else generate_freq_grid_pytorch\n        freqs_cis = precompute_freqs_cis(\n            indices_grid=indices_grid,\n            dim=self.inner_dim,\n            out_dtype=hidden_states.dtype,\n            theta=self.positional_embedding_theta,\n            max_pos=self.positional_embedding_max_pos,\n            num_attention_heads=self.num_attention_heads,\n            rope_type=self.rope_type,\n            freq_grid_generator=freq_grid_generator,\n        )\n\n        for block in self.transformer_1d_blocks:\n            hidden_states = block(hidden_states, attention_mask=attention_mask, pe=freqs_cis)\n\n        hidden_states = rms_norm(hidden_states)\n\n        return hidden_states, attention_mask\n\n\nclass LTX2TextEncoderPostModules(nn.Module):\n    def __init__(\n        self,\n        separated_audio_video: bool = False,\n        embedding_dim_gemma: int = 3840,\n        num_layers_gemma: int = 49,\n        video_attention_heads: int = 32,\n        video_attention_head_dim: int = 128,\n        audio_attention_heads: int = 32,\n        audio_attention_head_dim: int = 64,\n        num_connector_layers: int = 2,\n        apply_gated_attention: bool = False,\n    ):\n        super().__init__()\n        if not separated_audio_video:\n            self.feature_extractor_linear = GemmaFeaturesExtractorProjLinear()\n            self.embeddings_connector = Embeddings1DConnector()\n            
self.audio_embeddings_connector = Embeddings1DConnector()\n        else:\n            # LTX-2.3\n            self.feature_extractor_linear = GemmaSeperatedFeaturesExtractorProjLinear(\n                num_layers_gemma, embedding_dim_gemma, video_attention_heads * video_attention_head_dim,\n                audio_attention_heads * audio_attention_head_dim)\n            self.embeddings_connector = Embeddings1DConnector(\n                attention_head_dim=video_attention_head_dim,\n                num_attention_heads=video_attention_heads,\n                num_layers=num_connector_layers,\n                apply_gated_attention=apply_gated_attention,\n            )\n            self.audio_embeddings_connector = Embeddings1DConnector(\n                attention_head_dim=audio_attention_head_dim,\n                num_attention_heads=audio_attention_heads,\n                num_layers=num_connector_layers,\n                apply_gated_attention=apply_gated_attention,\n            )\n\n    def create_embeddings(\n        self,\n        video_features: torch.Tensor,\n        audio_features: torch.Tensor | None,\n        additive_attention_mask: torch.Tensor,\n    ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor]:\n        video_encoded, video_mask = self.embeddings_connector(video_features, additive_attention_mask)\n        video_encoded, binary_mask = _to_binary_mask(video_encoded, video_mask)\n        audio_encoded, _ = self.audio_embeddings_connector(audio_features, additive_attention_mask)\n\n        return video_encoded, audio_encoded, binary_mask\n\n    def process_hidden_states(\n        self,\n        hidden_states: tuple[torch.Tensor, ...],\n        attention_mask: torch.Tensor,\n        padding_side: str = \"left\",\n    ):\n        video_feats, audio_feats = self.feature_extractor_linear(hidden_states, attention_mask, padding_side)\n        additive_mask = _convert_to_additive_mask(attention_mask, video_feats.dtype)\n        video_enc, audio_enc, 
binary_mask = self.create_embeddings(video_feats, audio_feats, additive_mask)\n        return video_enc, audio_enc, binary_mask\n\n\ndef _norm_and_concat_padded_batch(\n    encoded_text: torch.Tensor,\n    sequence_lengths: torch.Tensor,\n    padding_side: str = \"right\",\n) -> torch.Tensor:\n    \"\"\"Normalize and flatten multi-layer hidden states, respecting padding.\n    Performs per-batch, per-layer normalization using masked mean and range,\n    then concatenates across the layer dimension.\n    Args:\n        encoded_text: Hidden states of shape [batch, seq_len, hidden_dim, num_layers].\n        sequence_lengths: Number of valid (non-padded) tokens per batch item.\n        padding_side: Whether padding is on \"left\" or \"right\".\n    Returns:\n        Normalized tensor of shape [batch, seq_len, hidden_dim * num_layers],\n        with padded positions zeroed out.\n    \"\"\"\n    b, t, d, l = encoded_text.shape  # noqa: E741\n    device = encoded_text.device\n    # Build mask: [B, T, 1, 1]\n    token_indices = torch.arange(t, device=device)[None, :]  # [1, T]\n    if padding_side == \"right\":\n        # For right padding, valid tokens are from 0 to sequence_length-1\n        mask = token_indices < sequence_lengths[:, None]  # [B, T]\n    elif padding_side == \"left\":\n        # For left padding, valid tokens are from (T - sequence_length) to T-1\n        start_indices = t - sequence_lengths[:, None]  # [B, 1]\n        mask = token_indices >= start_indices  # [B, T]\n    else:\n        raise ValueError(f\"padding_side must be 'left' or 'right', got {padding_side}\")\n    mask = rearrange(mask, \"b t -> b t 1 1\")\n    eps = 1e-6\n    # Compute masked mean: [B, 1, 1, L]\n    masked = encoded_text.masked_fill(~mask, 0.0)\n    denom = (sequence_lengths * d).view(b, 1, 1, 1)\n    mean = masked.sum(dim=(1, 2), keepdim=True) / (denom + eps)\n    # Compute masked min/max: [B, 1, 1, L]\n    x_min = encoded_text.masked_fill(~mask, float(\"inf\")).amin(dim=(1, 2), 
keepdim=True)\n    x_max = encoded_text.masked_fill(~mask, float(\"-inf\")).amax(dim=(1, 2), keepdim=True)\n    range_ = x_max - x_min\n    # Normalize only the valid tokens\n    normed = 8 * (encoded_text - mean) / (range_ + eps)\n    # concat to be [Batch, T,  D * L] - this preserves the original structure\n    normed = normed.reshape(b, t, -1)  # [B, T, D * L]\n    # Apply mask to preserve original padding (set padded positions to 0)\n    mask_flattened = rearrange(mask, \"b t 1 1 -> b t 1\").expand(-1, -1, d * l)\n    normed = normed.masked_fill(~mask_flattened, 0.0)\n\n    return normed\n\n\ndef _convert_to_additive_mask(attention_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:\n    return (attention_mask - 1).to(dtype).reshape(\n        (attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(dtype).max\n\ndef _to_binary_mask(encoded: torch.Tensor, encoded_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:\n    \"\"\"Convert connector output mask to binary mask and apply to encoded tensor.\"\"\"\n    binary_mask = (encoded_mask < 0.000001).to(torch.int64)\n    binary_mask = binary_mask.reshape([encoded.shape[0], encoded.shape[1], 1])\n    encoded = encoded * binary_mask\n    return encoded, binary_mask\n\n\ndef norm_and_concat_per_token_rms(\n    encoded_text: torch.Tensor,\n    attention_mask: torch.Tensor,\n) -> torch.Tensor:\n    \"\"\"Per-token RMSNorm normalization for V2 models.\n    Args:\n        encoded_text: [B, T, D, L]\n        attention_mask: [B, T] binary mask\n    Returns:\n        [B, T, D*L] normalized tensor with padding zeroed out.\n    \"\"\"\n    B, T, D, L = encoded_text.shape  # noqa: N806\n    variance = torch.mean(encoded_text**2, dim=2, keepdim=True)  # [B,T,1,L]\n    normed = encoded_text * torch.rsqrt(variance + 1e-6)\n    normed = normed.reshape(B, T, D * L)\n    mask_3d = attention_mask.bool().unsqueeze(-1)  # [B, T, 1]\n    return torch.where(mask_3d, normed, torch.zeros_like(normed))\n\n\ndef 
_rescale_norm(x: torch.Tensor, target_dim: int, source_dim: int) -> torch.Tensor:\n    \"\"\"Rescale normalization: x * sqrt(target_dim / source_dim).\"\"\"\n    return x * math.sqrt(target_dim / source_dim)\n"
  },
  {
    "path": "diffsynth/models/ltx2_upsampler.py",
    "content": "import math\nfrom typing import Optional, Tuple\nimport torch\nfrom einops import rearrange\nimport torch.nn.functional as F\nfrom .ltx2_video_vae import LTX2VideoEncoder\n\nclass PixelShuffleND(torch.nn.Module):\n    \"\"\"\n    N-dimensional pixel shuffle operation for upsampling tensors.\n    Args:\n        dims (int): Number of dimensions to apply pixel shuffle to.\n            - 1: Temporal (e.g., frames)\n            - 2: Spatial (e.g., height and width)\n            - 3: Spatiotemporal (e.g., depth, height, width)\n        upscale_factors (tuple[int, int, int], optional): Upscaling factors for each dimension.\n            For dims=1, only the first value is used.\n            For dims=2, the first two values are used.\n            For dims=3, all three values are used.\n    The input tensor is rearranged so that the channel dimension is split into\n    smaller channels and upscaling factors, and the upscaling factors are moved\n    into the corresponding spatial/temporal dimensions.\n    Note:\n    This operation is equivalent to the patchifier operation in for the models. 
Consider\n    using this class instead.\n    \"\"\"\n\n    def __init__(self, dims: int, upscale_factors: tuple[int, int, int] = (2, 2, 2)):\n        super().__init__()\n        assert dims in [1, 2, 3], \"dims must be 1, 2, or 3\"\n        self.dims = dims\n        self.upscale_factors = upscale_factors\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        if self.dims == 3:\n            return rearrange(\n                x,\n                \"b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)\",\n                p1=self.upscale_factors[0],\n                p2=self.upscale_factors[1],\n                p3=self.upscale_factors[2],\n            )\n        elif self.dims == 2:\n            return rearrange(\n                x,\n                \"b (c p1 p2) h w -> b c (h p1) (w p2)\",\n                p1=self.upscale_factors[0],\n                p2=self.upscale_factors[1],\n            )\n        elif self.dims == 1:\n            return rearrange(\n                x,\n                \"b (c p1) f h w -> b c (f p1) h w\",\n                p1=self.upscale_factors[0],\n            )\n        else:\n            raise ValueError(f\"Unsupported dims: {self.dims}\")\n\n\nclass ResBlock(torch.nn.Module):\n    \"\"\"\n    Residual block with two convolutional layers, group normalization, and SiLU activation.\n    Args:\n        channels (int): Number of input and output channels.\n        mid_channels (Optional[int]): Number of channels in the intermediate convolution layer. Defaults to `channels`\n        if not specified.\n        dims (int): Dimensionality of the convolution (2 for Conv2d, 3 for Conv3d). 
Defaults to 3.\n    \"\"\"\n\n    def __init__(self, channels: int, mid_channels: Optional[int] = None, dims: int = 3):\n        super().__init__()\n        if mid_channels is None:\n            mid_channels = channels\n\n        conv = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d\n\n        self.conv1 = conv(channels, mid_channels, kernel_size=3, padding=1)\n        self.norm1 = torch.nn.GroupNorm(32, mid_channels)\n        self.conv2 = conv(mid_channels, channels, kernel_size=3, padding=1)\n        self.norm2 = torch.nn.GroupNorm(32, channels)\n        self.activation = torch.nn.SiLU()\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        residual = x\n        x = self.conv1(x)\n        x = self.norm1(x)\n        x = self.activation(x)\n        x = self.conv2(x)\n        x = self.norm2(x)\n        x = self.activation(x + residual)\n        return x\n\n\nclass BlurDownsample(torch.nn.Module):\n    \"\"\"\n    Anti-aliased spatial downsampling by integer stride using a fixed separable binomial kernel.\n    Applies only on H,W. Works for dims=2 or dims=3 (per-frame).\n    \"\"\"\n\n    def __init__(self, dims: int, stride: int, kernel_size: int = 5) -> None:\n        super().__init__()\n        assert dims in (2, 3)\n        assert isinstance(stride, int)\n        assert stride >= 1\n        assert kernel_size >= 3\n        assert kernel_size % 2 == 1\n        self.dims = dims\n        self.stride = stride\n        self.kernel_size = kernel_size\n\n        # 5x5 separable binomial kernel using binomial coefficients [1, 4, 6, 4, 1] from\n        # the 4th row of Pascal's triangle. 
This kernel is used for anti-aliasing and\n        # provides a smooth approximation of a Gaussian filter (often called a \"binomial filter\").\n        # The 2D kernel is constructed as the outer product and normalized.\n        k = torch.tensor([math.comb(kernel_size - 1, k) for k in range(kernel_size)])\n        k2d = k[:, None] @ k[None, :]\n        k2d = (k2d / k2d.sum()).float()  # shape (kernel_size, kernel_size)\n        self.register_buffer(\"kernel\", k2d[None, None, :, :])  # (1, 1, kernel_size, kernel_size)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        if self.stride == 1:\n            return x\n\n        if self.dims == 2:\n            return self._apply_2d(x)\n        else:\n            # dims == 3: apply per-frame on H,W\n            b, _, f, _, _ = x.shape\n            x = rearrange(x, \"b c f h w -> (b f) c h w\")\n            x = self._apply_2d(x)\n            h2, w2 = x.shape[-2:]\n            x = rearrange(x, \"(b f) c h w -> b c f h w\", b=b, f=f, h=h2, w=w2)\n            return x\n\n    def _apply_2d(self, x2d: torch.Tensor) -> torch.Tensor:\n        c = x2d.shape[1]\n        weight = self.kernel.expand(c, 1, self.kernel_size, self.kernel_size)  # depthwise\n        x2d = F.conv2d(x2d, weight=weight, bias=None, stride=self.stride, padding=self.kernel_size // 2, groups=c)\n        return x2d\n\n\ndef _rational_for_scale(scale: float) -> Tuple[int, int]:\n    mapping = {0.75: (3, 4), 1.5: (3, 2), 2.0: (2, 1), 4.0: (4, 1)}\n    if float(scale) not in mapping:\n        raise ValueError(f\"Unsupported scale {scale}. Choose from {list(mapping.keys())}\")\n    return mapping[float(scale)]\n\n\nclass SpatialRationalResampler(torch.nn.Module):\n    \"\"\"\n    Fully-learned rational spatial scaling: up by 'num' via PixelShuffle, then anti-aliased\n    downsample by 'den' using fixed blur + stride. 
Operates on H,W only.\n    For dims==3, work per-frame for spatial scaling (temporal axis untouched).\n    Args:\n        mid_channels (`int`): Number of intermediate channels for the convolution layer\n        scale (`float`): Spatial scaling factor. Supported values are:\n            - 0.75: Downsample by 3/4 (reduce spatial size)\n            - 1.5: Upsample by 3/2 (increase spatial size)\n            - 2.0: Upsample by 2x (double spatial size)\n            - 4.0: Upsample by 4x (quadruple spatial size)\n            Any other value will raise a ValueError.\n    \"\"\"\n\n    def __init__(self, mid_channels: int, scale: float):\n        super().__init__()\n        self.scale = float(scale)\n        self.num, self.den = _rational_for_scale(self.scale)\n        self.conv = torch.nn.Conv2d(mid_channels, (self.num**2) * mid_channels, kernel_size=3, padding=1)\n        self.pixel_shuffle = PixelShuffleND(2, upscale_factors=(self.num, self.num))\n        self.blur_down = BlurDownsample(dims=2, stride=self.den)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        b, _, f, _, _ = x.shape\n        x = rearrange(x, \"b c f h w -> (b f) c h w\")\n        x = self.conv(x)\n        x = self.pixel_shuffle(x)\n        x = self.blur_down(x)\n        x = rearrange(x, \"(b f) c h w -> b c f h w\", b=b, f=f)\n        return x\n\n\nclass LTX2LatentUpsampler(torch.nn.Module):\n    \"\"\"\n    Model to upsample VAE latents spatially and/or temporally.\n    Args:\n        in_channels (`int`): Number of channels in the input latent\n        mid_channels (`int`): Number of channels in the middle layers\n        num_blocks_per_stage (`int`): Number of ResBlocks to use in each stage (pre/post upsampling)\n        dims (`int`): Number of dimensions for convolutions (2 or 3)\n        spatial_upsample (`bool`): Whether to spatially upsample the latent\n        temporal_upsample (`bool`): Whether to temporally upsample the latent\n        spatial_scale (`float`): Scale factor 
for spatial upsampling\n        rational_resampler (`bool`): Whether to use a rational resampler for spatial upsampling\n    \"\"\"\n    def __init__(\n        self,\n        in_channels: int = 128,\n        mid_channels: int = 1024,\n        num_blocks_per_stage: int = 4,\n        dims: int = 3,\n        spatial_upsample: bool = True,\n        temporal_upsample: bool = False,\n        spatial_scale: float = 2.0,\n        rational_resampler: bool = True,\n    ):\n        super().__init__()\n\n        self.in_channels = in_channels\n        self.mid_channels = mid_channels\n        self.num_blocks_per_stage = num_blocks_per_stage\n        self.dims = dims\n        self.spatial_upsample = spatial_upsample\n        self.temporal_upsample = temporal_upsample\n        self.spatial_scale = float(spatial_scale)\n        self.rational_resampler = rational_resampler\n\n        conv = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d\n\n        self.initial_conv = conv(in_channels, mid_channels, kernel_size=3, padding=1)\n        self.initial_norm = torch.nn.GroupNorm(32, mid_channels)\n        self.initial_activation = torch.nn.SiLU()\n\n        self.res_blocks = torch.nn.ModuleList([ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)])\n\n        if spatial_upsample and temporal_upsample:\n            self.upsampler = torch.nn.Sequential(\n                torch.nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1),\n                PixelShuffleND(3),\n            )\n        elif spatial_upsample:\n            if rational_resampler:\n                self.upsampler = SpatialRationalResampler(mid_channels=mid_channels, scale=self.spatial_scale)\n            else:\n                self.upsampler = torch.nn.Sequential(\n                    torch.nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1),\n                    PixelShuffleND(2),\n                )\n        elif temporal_upsample:\n            self.upsampler = 
torch.nn.Sequential(\n                torch.nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1),\n                PixelShuffleND(1),\n            )\n        else:\n            raise ValueError(\"Either spatial_upsample or temporal_upsample must be True\")\n\n        self.post_upsample_res_blocks = torch.nn.ModuleList(\n            [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]\n        )\n\n        self.final_conv = conv(mid_channels, in_channels, kernel_size=3, padding=1)\n\n    def forward(self, latent: torch.Tensor) -> torch.Tensor:\n        b, _, f, _, _ = latent.shape\n\n        if self.dims == 2:\n            x = rearrange(latent, \"b c f h w -> (b f) c h w\")\n            x = self.initial_conv(x)\n            x = self.initial_norm(x)\n            x = self.initial_activation(x)\n\n            for block in self.res_blocks:\n                x = block(x)\n\n            x = self.upsampler(x)\n\n            for block in self.post_upsample_res_blocks:\n                x = block(x)\n\n            x = self.final_conv(x)\n            x = rearrange(x, \"(b f) c h w -> b c f h w\", b=b, f=f)\n        else:\n            x = self.initial_conv(latent)\n            x = self.initial_norm(x)\n            x = self.initial_activation(x)\n\n            for block in self.res_blocks:\n                x = block(x)\n\n            if self.temporal_upsample:\n                x = self.upsampler(x)\n                # remove the first frame after upsampling.\n                # This is done because the first frame encodes one pixel frame.\n                x = x[:, :, 1:, :, :]\n            elif isinstance(self.upsampler, SpatialRationalResampler):\n                x = self.upsampler(x)\n            else:\n                x = rearrange(x, \"b c f h w -> (b f) c h w\")\n                x = self.upsampler(x)\n                x = rearrange(x, \"(b f) c h w -> b c f h w\", b=b, f=f)\n\n            for block in self.post_upsample_res_blocks:\n           
     x = block(x)\n\n            x = self.final_conv(x)\n\n        return x\n\n\ndef upsample_video(latent: torch.Tensor, video_encoder: LTX2VideoEncoder, upsampler: \"LTX2LatentUpsampler\") -> torch.Tensor:\n    \"\"\"\n    Apply upsampling to the latent representation using the provided upsampler,\n    with normalization and un-normalization based on the video encoder's per-channel statistics.\n    Args:\n        latent: Input latent tensor of shape [B, C, F, H, W].\n        video_encoder: VideoEncoder with per_channel_statistics for normalization.\n        upsampler: LTX2LatentUpsampler module to perform upsampling.\n    Returns:\n        torch.Tensor: Upsampled and re-normalized latent tensor.\n    \"\"\"\n    latent = video_encoder.per_channel_statistics.un_normalize(latent)\n    latent = upsampler(latent)\n    latent = video_encoder.per_channel_statistics.normalize(latent)\n    return latent\n"
  },
  {
    "path": "diffsynth/models/ltx2_video_vae.py",
    "content": "import itertools\nimport math\nimport einops\nfrom dataclasses import replace, dataclass\nfrom typing import Any, Callable, Iterator, List, NamedTuple, Tuple, Union, Optional\nimport torch\nfrom einops import rearrange\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom enum import Enum\nfrom .ltx2_common import PixelNorm, SpatioTemporalScaleFactors, VideoLatentShape, Patchifier, AudioLatentShape\nfrom .ltx2_dit import PixArtAlphaCombinedTimestepSizeEmbeddings\n\nVAE_SPATIAL_FACTOR = 32\nVAE_TEMPORAL_FACTOR = 8\n\n\nclass VideoLatentPatchifier(Patchifier):\n    def __init__(self, patch_size: int):\n        # Patch sizes for video latents.\n        self._patch_size = (\n            1,  # temporal dimension\n            patch_size,  # height dimension\n            patch_size,  # width dimension\n        )\n\n    @property\n    def patch_size(self) -> Tuple[int, int, int]:\n        return self._patch_size\n\n    def get_token_count(self, tgt_shape: VideoLatentShape) -> int:\n        return math.prod(tgt_shape.to_torch_shape()[2:]) // math.prod(self._patch_size)\n\n    def patchify(\n        self,\n        latents: torch.Tensor,\n    ) -> torch.Tensor:\n        latents = einops.rearrange(\n            latents,\n            \"b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)\",\n            p1=self._patch_size[0],\n            p2=self._patch_size[1],\n            p3=self._patch_size[2],\n        )\n\n        return latents\n\n    def unpatchify(\n        self,\n        latents: torch.Tensor,\n        output_shape: VideoLatentShape,\n    ) -> torch.Tensor:\n        assert self._patch_size[0] == 1, \"Temporal patch size must be 1 for symmetric patchifier\"\n\n        patch_grid_frames = output_shape.frames // self._patch_size[0]\n        patch_grid_height = output_shape.height // self._patch_size[1]\n        patch_grid_width = output_shape.width // self._patch_size[2]\n\n        latents = einops.rearrange(\n            latents,\n            
\"b (f h w) (c p q) -> b c f (h p) (w q)\",\n            f=patch_grid_frames,\n            h=patch_grid_height,\n            w=patch_grid_width,\n            p=self._patch_size[1],\n            q=self._patch_size[2],\n        )\n\n        return latents\n\n    def unpatchify_video(\n        self,\n        latents: torch.Tensor,\n        frames: int,\n        height: int,\n        width: int,\n    ) -> torch.Tensor:\n        latents = einops.rearrange(\n            latents,\n            \"b (f h w) (c p q) -> b c f (h p) (w q)\",\n            f=frames,\n            h=height // self._patch_size[1],\n            w=width // self._patch_size[2],\n            p=self._patch_size[1],\n            q=self._patch_size[2],\n        )\n        return latents\n\n    def get_patch_grid_bounds(\n        self,\n        output_shape: AudioLatentShape | VideoLatentShape,\n        device: Optional[torch.device] = None,\n    ) -> torch.Tensor:\n        \"\"\"\n        Return the per-dimension bounds [inclusive start, exclusive end) for every\n        patch produced by `patchify`. 
The bounds are expressed in the original\n        video grid coordinates: frame/time, height, and width.\n        The resulting tensor is shaped `[batch_size, 3, num_patches, 2]`, where:\n            - axis 1 (size 3) enumerates (frame/time, height, width) dimensions\n            - axis 3 (size 2) stores `[start, end)` indices within each dimension\n        Args:\n            output_shape: Video grid description containing frames, height, and width.\n            device: Device of the latent tensor.\n        \"\"\"\n        if not isinstance(output_shape, VideoLatentShape):\n            raise ValueError(\"VideoLatentPatchifier expects VideoLatentShape when computing coordinates\")\n\n        frames = output_shape.frames\n        height = output_shape.height\n        width = output_shape.width\n        batch_size = output_shape.batch\n\n        # Validate inputs to ensure positive dimensions\n        assert frames > 0, f\"frames must be positive, got {frames}\"\n        assert height > 0, f\"height must be positive, got {height}\"\n        assert width > 0, f\"width must be positive, got {width}\"\n        assert batch_size > 0, f\"batch_size must be positive, got {batch_size}\"\n\n        # Generate grid coordinates for each dimension (frame, height, width)\n        # We use torch.arange to create the starting coordinates for each patch.\n        # indexing='ij' ensures the dimensions are in the order (frame, height, width).\n        grid_coords = torch.meshgrid(\n            torch.arange(start=0, end=frames, step=self._patch_size[0], device=device),\n            torch.arange(start=0, end=height, step=self._patch_size[1], device=device),\n            torch.arange(start=0, end=width, step=self._patch_size[2], device=device),\n            indexing=\"ij\",\n        )\n\n        # Stack the grid coordinates to create the start coordinates tensor.\n        # Shape becomes (3, grid_f, grid_h, grid_w)\n        patch_starts = torch.stack(grid_coords, dim=0)\n\n        # 
Create a tensor containing the size of a single patch:\n        # (frame_patch_size, height_patch_size, width_patch_size).\n        # Reshape to (3, 1, 1, 1) to enable broadcasting when adding to the start coordinates.\n        patch_size_delta = torch.tensor(\n            self._patch_size,\n            device=patch_starts.device,\n            dtype=patch_starts.dtype,\n        ).view(3, 1, 1, 1)\n\n        # Calculate end coordinates: start + patch_size\n        # Shape becomes (3, grid_f, grid_h, grid_w)\n        patch_ends = patch_starts + patch_size_delta\n\n        # Stack start and end coordinates together along the last dimension\n        # Shape becomes (3, grid_f, grid_h, grid_w, 2), where the last dimension is [start, end]\n        latent_coords = torch.stack((patch_starts, patch_ends), dim=-1)\n\n        # Broadcast to batch size and flatten all spatial/temporal dimensions into one sequence.\n        # Final Shape: (batch_size, 3, num_patches, 2)\n        latent_coords = einops.repeat(\n            latent_coords,\n            \"c f h w bounds -> b c (f h w) bounds\",\n            b=batch_size,\n            bounds=2,\n        )\n\n        return latent_coords\n\n\nclass NormLayerType(Enum):\n    GROUP_NORM = \"group_norm\"\n    PIXEL_NORM = \"pixel_norm\"\n\n\nclass LogVarianceType(Enum):\n    PER_CHANNEL = \"per_channel\"\n    UNIFORM = \"uniform\"\n    CONSTANT = \"constant\"\n    NONE = \"none\"\n\n\nclass PaddingModeType(Enum):\n    ZEROS = \"zeros\"\n    REFLECT = \"reflect\"\n    REPLICATE = \"replicate\"\n    CIRCULAR = \"circular\"\n\n\nclass DualConv3d(nn.Module):\n\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        kernel_size: int,\n        stride: Union[int, Tuple[int, int, int]] = 1,\n        padding: Union[int, Tuple[int, int, int]] = 0,\n        dilation: Union[int, Tuple[int, int, int]] = 1,\n        groups: int = 1,\n        bias: bool = True,\n        padding_mode: str = \"zeros\",\n    ) -> 
None:\n        super(DualConv3d, self).__init__()\n\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.padding_mode = padding_mode\n        # Ensure kernel_size, stride, padding, and dilation are tuples of length 3\n        if isinstance(kernel_size, int):\n            kernel_size = (kernel_size, kernel_size, kernel_size)\n        if kernel_size == (1, 1, 1):\n            raise ValueError(\"kernel_size must be greater than 1. Use make_linear_nd instead.\")\n        if isinstance(stride, int):\n            stride = (stride, stride, stride)\n        if isinstance(padding, int):\n            padding = (padding, padding, padding)\n        if isinstance(dilation, int):\n            dilation = (dilation, dilation, dilation)\n\n        # Set parameters for convolutions\n        self.groups = groups\n        self.bias = bias\n\n        # Define the size of the channels after the first convolution\n        intermediate_channels = out_channels if in_channels < out_channels else in_channels\n\n        # Define parameters for the first convolution\n        self.weight1 = nn.Parameter(\n            torch.Tensor(\n                intermediate_channels,\n                in_channels // groups,\n                1,\n                kernel_size[1],\n                kernel_size[2],\n            ))\n        self.stride1 = (1, stride[1], stride[2])\n        self.padding1 = (0, padding[1], padding[2])\n        self.dilation1 = (1, dilation[1], dilation[2])\n        if bias:\n            self.bias1 = nn.Parameter(torch.Tensor(intermediate_channels))\n        else:\n            self.register_parameter(\"bias1\", None)\n\n        # Define parameters for the second convolution\n        self.weight2 = nn.Parameter(torch.Tensor(out_channels, intermediate_channels // groups, kernel_size[0], 1, 1))\n        self.stride2 = (stride[0], 1, 1)\n        self.padding2 = (padding[0], 0, 0)\n        self.dilation2 = (dilation[0], 1, 1)\n        if bias:\n     
       self.bias2 = nn.Parameter(torch.Tensor(out_channels))\n        else:\n            self.register_parameter(\"bias2\", None)\n\n        # Initialize weights and biases\n        self.reset_parameters()\n\n    def reset_parameters(self) -> None:\n        nn.init.kaiming_uniform_(self.weight1, a=torch.sqrt(5))\n        nn.init.kaiming_uniform_(self.weight2, a=torch.sqrt(5))\n        if self.bias:\n            fan_in1, _ = nn.init._calculate_fan_in_and_fan_out(self.weight1)\n            bound1 = 1 / torch.sqrt(fan_in1)\n            nn.init.uniform_(self.bias1, -bound1, bound1)\n            fan_in2, _ = nn.init._calculate_fan_in_and_fan_out(self.weight2)\n            bound2 = 1 / torch.sqrt(fan_in2)\n            nn.init.uniform_(self.bias2, -bound2, bound2)\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        use_conv3d: bool = False,\n        skip_time_conv: bool = False,\n    ) -> torch.Tensor:\n        if use_conv3d:\n            return self.forward_with_3d(x=x, skip_time_conv=skip_time_conv)\n        else:\n            return self.forward_with_2d(x=x, skip_time_conv=skip_time_conv)\n\n    def forward_with_3d(self, x: torch.Tensor, skip_time_conv: bool = False) -> torch.Tensor:\n        # First convolution\n        x = F.conv3d(\n            x,\n            self.weight1,\n            self.bias1,\n            self.stride1,\n            self.padding1,\n            self.dilation1,\n            self.groups,\n            padding_mode=self.padding_mode,\n        )\n\n        if skip_time_conv:\n            return x\n\n        # Second convolution\n        x = F.conv3d(\n            x,\n            self.weight2,\n            self.bias2,\n            self.stride2,\n            self.padding2,\n            self.dilation2,\n            self.groups,\n            padding_mode=self.padding_mode,\n        )\n\n        return x\n\n    def forward_with_2d(self, x: torch.Tensor, skip_time_conv: bool = False) -> torch.Tensor:\n        b, _, _, h, w = x.shape\n\n   
     # First 2D convolution\n        x = rearrange(x, \"b c d h w -> (b d) c h w\")\n        # Squeeze the depth dimension out of weight1 since it's 1\n        weight1 = self.weight1.squeeze(2)\n        # Select stride, padding, and dilation for the 2D convolution\n        stride1 = (self.stride1[1], self.stride1[2])\n        padding1 = (self.padding1[1], self.padding1[2])\n        dilation1 = (self.dilation1[1], self.dilation1[2])\n        x = F.conv2d(\n            x,\n            weight1,\n            self.bias1,\n            stride1,\n            padding1,\n            dilation1,\n            self.groups,\n            padding_mode=self.padding_mode,\n        )\n\n        _, _, h, w = x.shape\n\n        if skip_time_conv:\n            x = rearrange(x, \"(b d) c h w -> b c d h w\", b=b)\n            return x\n\n        # Second convolution which is essentially treated as a 1D convolution across the 'd' dimension\n        x = rearrange(x, \"(b d) c h w -> (b h w) c d\", b=b)\n\n        # Reshape weight2 to match the expected dimensions for conv1d\n        weight2 = self.weight2.squeeze(-1).squeeze(-1)\n        # Use only the relevant dimension for stride, padding, and dilation for the 1D convolution\n        stride2 = self.stride2[0]\n        padding2 = self.padding2[0]\n        dilation2 = self.dilation2[0]\n        x = F.conv1d(\n            x,\n            weight2,\n            self.bias2,\n            stride2,\n            padding2,\n            dilation2,\n            self.groups,\n            padding_mode=self.padding_mode,\n        )\n        x = rearrange(x, \"(b h w) c d -> b c d h w\", b=b, h=h, w=w)\n\n        return x\n\n    @property\n    def weight(self) -> torch.Tensor:\n        return self.weight2\n\n\nclass CausalConv3d(nn.Module):\n\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        kernel_size: int = 3,\n        stride: Union[int, Tuple[int]] = 1,\n        dilation: int = 1,\n        groups: int = 
1,\n        bias: bool = True,\n        spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,\n    ) -> None:\n        super().__init__()\n\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n\n        kernel_size = (kernel_size, kernel_size, kernel_size)\n        self.time_kernel_size = kernel_size[0]\n\n        dilation = (dilation, 1, 1)\n\n        height_pad = kernel_size[1] // 2\n        width_pad = kernel_size[2] // 2\n        padding = (0, height_pad, width_pad)\n\n        self.conv = nn.Conv3d(\n            in_channels,\n            out_channels,\n            kernel_size,\n            stride=stride,\n            dilation=dilation,\n            padding=padding,\n            padding_mode=spatial_padding_mode.value,\n            groups=groups,\n            bias=bias,\n        )\n\n    def forward(self, x: torch.Tensor, causal: bool = True) -> torch.Tensor:\n        if causal:\n            first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, self.time_kernel_size - 1, 1, 1))\n            x = torch.concatenate((first_frame_pad, x), dim=2)\n        else:\n            first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, (self.time_kernel_size - 1) // 2, 1, 1))\n            last_frame_pad = x[:, :, -1:, :, :].repeat((1, 1, (self.time_kernel_size - 1) // 2, 1, 1))\n            x = torch.concatenate((first_frame_pad, x, last_frame_pad), dim=2)\n        x = self.conv(x)\n        return x\n\n    @property\n    def weight(self) -> torch.Tensor:\n        return self.conv.weight\n\n\ndef make_conv_nd(  # noqa: PLR0913\n    dims: Union[int, Tuple[int, int]],\n    in_channels: int,\n    out_channels: int,\n    kernel_size: int,\n    stride: int = 1,\n    padding: int = 0,\n    dilation: int = 1,\n    groups: int = 1,\n    bias: bool = True,\n    causal: bool = False,\n    spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,\n    temporal_padding_mode: PaddingModeType = PaddingModeType.ZEROS,\n) -> nn.Module:\n    if not 
(spatial_padding_mode == temporal_padding_mode or causal):\n        raise NotImplementedError(\"spatial and temporal padding modes must be equal\")\n    if dims == 2:\n        return nn.Conv2d(\n            in_channels=in_channels,\n            out_channels=out_channels,\n            kernel_size=kernel_size,\n            stride=stride,\n            padding=padding,\n            dilation=dilation,\n            groups=groups,\n            bias=bias,\n            padding_mode=spatial_padding_mode.value,\n        )\n    elif dims == 3:\n        if causal:\n            return CausalConv3d(\n                in_channels=in_channels,\n                out_channels=out_channels,\n                kernel_size=kernel_size,\n                stride=stride,\n                dilation=dilation,\n                groups=groups,\n                bias=bias,\n                spatial_padding_mode=spatial_padding_mode,\n            )\n        return nn.Conv3d(\n            in_channels=in_channels,\n            out_channels=out_channels,\n            kernel_size=kernel_size,\n            stride=stride,\n            padding=padding,\n            dilation=dilation,\n            groups=groups,\n            bias=bias,\n            padding_mode=spatial_padding_mode.value,\n        )\n    elif dims == (2, 1):\n        return DualConv3d(\n            in_channels=in_channels,\n            out_channels=out_channels,\n            kernel_size=kernel_size,\n            stride=stride,\n            padding=padding,\n            bias=bias,\n            padding_mode=spatial_padding_mode.value,\n        )\n    else:\n        raise ValueError(f\"unsupported dimensions: {dims}\")\n\n\ndef make_linear_nd(\n    dims: int,\n    in_channels: int,\n    out_channels: int,\n    bias: bool = True,\n) -> nn.Module:\n    if dims == 2:\n        return nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias)\n    elif dims in (3, (2, 1)):\n        return nn.Conv3d(in_channels=in_channels, 
out_channels=out_channels, kernel_size=1, bias=bias)\n    else:\n        raise ValueError(f\"unsupported dimensions: {dims}\")\n\n\ndef patchify(x: torch.Tensor, patch_size_hw: int, patch_size_t: int = 1) -> torch.Tensor:\n    \"\"\"\n    Rearrange spatial dimensions into channels. Divides image into patch_size x patch_size blocks\n    and moves pixels from each block into separate channels (space-to-depth).\n    Args:\n        x: Input tensor (4D or 5D)\n        patch_size_hw: Spatial patch size for height and width. With patch_size_hw=4, divides HxW into 4x4 blocks.\n        patch_size_t: Temporal patch size for frames. Default=1 (no temporal patching).\n    For 5D: (B, C, F, H, W) -> (B, Cx(patch_size_hw^2)x(patch_size_t), F/patch_size_t, H/patch_size_hw, W/patch_size_hw)\n    Example: (B, 3, 33, 512, 512) with patch_size_hw=4, patch_size_t=1 -> (B, 48, 33, 128, 128)\n    \"\"\"\n    if patch_size_hw == 1 and patch_size_t == 1:\n        return x\n    if x.dim() == 4:\n        x = rearrange(x, \"b c (h q) (w r) -> b (c r q) h w\", q=patch_size_hw, r=patch_size_hw)\n    elif x.dim() == 5:\n        x = rearrange(\n            x,\n            \"b c (f p) (h q) (w r) -> b (c p r q) f h w\",\n            p=patch_size_t,\n            q=patch_size_hw,\n            r=patch_size_hw,\n        )\n    else:\n        raise ValueError(f\"Invalid input shape: {x.shape}\")\n\n    return x\n\n\ndef unpatchify(x: torch.Tensor, patch_size_hw: int, patch_size_t: int = 1) -> torch.Tensor:\n    \"\"\"\n    Rearrange channels back into spatial dimensions. Inverse of patchify - moves pixels from\n    channels back into patch_size x patch_size blocks (depth-to-space).\n    Args:\n        x: Input tensor (4D or 5D)\n        patch_size_hw: Spatial patch size for height and width. With patch_size_hw=4, expands HxW by 4x.\n        patch_size_t: Temporal patch size for frames. 
Default=1 (no temporal expansion).\n    For 5D: (B, Cx(patch_size_hw^2)x(patch_size_t), F, H, W) -> (B, C, Fxpatch_size_t, Hxpatch_size_hw, Wxpatch_size_hw)\n    Example: (B, 48, 33, 128, 128) with patch_size_hw=4, patch_size_t=1 -> (B, 3, 33, 512, 512)\n    \"\"\"\n    if patch_size_hw == 1 and patch_size_t == 1:\n        return x\n\n    if x.dim() == 4:\n        x = rearrange(x, \"b (c r q) h w -> b c (h q) (w r)\", q=patch_size_hw, r=patch_size_hw)\n    elif x.dim() == 5:\n        x = rearrange(\n            x,\n            \"b (c p r q) f h w -> b c (f p) (h q) (w r)\",\n            p=patch_size_t,\n            q=patch_size_hw,\n            r=patch_size_hw,\n        )\n\n    return x\n\n\nclass PerChannelStatistics(nn.Module):\n    \"\"\"\n    Per-channel statistics for normalizing and denormalizing the latent representation.\n    This statics is computed over the entire dataset and stored in model's checkpoint under VAE state_dict.\n    \"\"\"\n\n    def __init__(self, latent_channels: int = 128):\n        super().__init__()\n        self.register_buffer(\"std-of-means\", torch.empty(latent_channels))\n        self.register_buffer(\"mean-of-means\", torch.empty(latent_channels))\n\n    def un_normalize(self, x: torch.Tensor) -> torch.Tensor:\n        return (x * self.get_buffer(\"std-of-means\").view(1, -1, 1, 1, 1).to(x)) + self.get_buffer(\"mean-of-means\").view(\n            1, -1, 1, 1, 1).to(x)\n\n    def normalize(self, x: torch.Tensor) -> torch.Tensor:\n        return (x - self.get_buffer(\"mean-of-means\").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer(\"std-of-means\").view(\n            1, -1, 1, 1, 1).to(x)\n\n\nclass ResnetBlock3D(nn.Module):\n    r\"\"\"\n    A Resnet block.\n    Parameters:\n        in_channels (`int`): The number of channels in the input.\n        out_channels (`int`, *optional*, default to be `None`):\n            The number of output channels for the first conv layer. 
If None, same as `in_channels`.\n        dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.\n        groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.\n        eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.\n    \"\"\"\n\n    def __init__(\n        self,\n        dims: Union[int, Tuple[int, int]],\n        in_channels: int,\n        out_channels: Optional[int] = None,\n        dropout: float = 0.0,\n        groups: int = 32,\n        eps: float = 1e-6,\n        norm_layer: NormLayerType = NormLayerType.PIXEL_NORM,\n        inject_noise: bool = False,\n        timestep_conditioning: bool = False,\n        spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,\n    ):\n        super().__init__()\n        self.in_channels = in_channels\n        out_channels = in_channels if out_channels is None else out_channels\n        self.out_channels = out_channels\n        self.inject_noise = inject_noise\n\n        if norm_layer == NormLayerType.GROUP_NORM:\n            self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)\n        elif norm_layer == NormLayerType.PIXEL_NORM:\n            self.norm1 = PixelNorm()\n\n        self.non_linearity = nn.SiLU()\n\n        self.conv1 = make_conv_nd(\n            dims,\n            in_channels,\n            out_channels,\n            kernel_size=3,\n            stride=1,\n            padding=1,\n            causal=True,\n            spatial_padding_mode=spatial_padding_mode,\n        )\n\n        if inject_noise:\n            self.per_channel_scale1 = nn.Parameter(torch.zeros((in_channels, 1, 1)))\n\n        if norm_layer == NormLayerType.GROUP_NORM:\n            self.norm2 = nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True)\n        elif norm_layer == NormLayerType.PIXEL_NORM:\n            self.norm2 = PixelNorm()\n\n        
self.dropout = torch.nn.Dropout(dropout)\n\n        self.conv2 = make_conv_nd(\n            dims,\n            out_channels,\n            out_channels,\n            kernel_size=3,\n            stride=1,\n            padding=1,\n            causal=True,\n            spatial_padding_mode=spatial_padding_mode,\n        )\n\n        if inject_noise:\n            self.per_channel_scale2 = nn.Parameter(torch.zeros((in_channels, 1, 1)))\n\n        self.conv_shortcut = (make_linear_nd(dims=dims, in_channels=in_channels, out_channels=out_channels)\n                              if in_channels != out_channels else nn.Identity())\n\n        # Using GroupNorm with 1 group is equivalent to LayerNorm but works with (B, C, ...) layout\n        # avoiding the need for dimension rearrangement used in standard nn.LayerNorm\n        self.norm3 = (nn.GroupNorm(num_groups=1, num_channels=in_channels, eps=eps, affine=True)\n                      if in_channels != out_channels else nn.Identity())\n\n        self.timestep_conditioning = timestep_conditioning\n\n        if timestep_conditioning:\n            self.scale_shift_table = nn.Parameter(torch.zeros(4, in_channels))\n\n    def _feed_spatial_noise(\n        self,\n        hidden_states: torch.Tensor,\n        per_channel_scale: torch.Tensor,\n        generator: Optional[torch.Generator] = None,\n    ) -> torch.Tensor:\n        spatial_shape = hidden_states.shape[-2:]\n        device = hidden_states.device\n        dtype = hidden_states.dtype\n\n        # similar to the \"explicit noise inputs\" method in style-gan\n        spatial_noise = torch.randn(spatial_shape, device=device, dtype=dtype, generator=generator)[None]\n        scaled_noise = (spatial_noise * per_channel_scale)[None, :, None, ...]\n        hidden_states = hidden_states + scaled_noise\n\n        return hidden_states\n\n    def forward(\n        self,\n        input_tensor: torch.Tensor,\n        causal: bool = True,\n        timestep: Optional[torch.Tensor] = None,\n 
       generator: Optional[torch.Generator] = None,\n    ) -> torch.Tensor:\n        hidden_states = input_tensor\n        batch_size = hidden_states.shape[0]\n\n        hidden_states = self.norm1(hidden_states)\n        if self.timestep_conditioning:\n            if timestep is None:\n                raise ValueError(\"'timestep' parameter must be provided when 'timestep_conditioning' is True\")\n            ada_values = self.scale_shift_table[None, ..., None, None, None].to(\n                device=hidden_states.device, dtype=hidden_states.dtype) + timestep.reshape(\n                    batch_size,\n                    4,\n                    -1,\n                    timestep.shape[-3],\n                    timestep.shape[-2],\n                    timestep.shape[-1],\n                )\n            shift1, scale1, shift2, scale2 = ada_values.unbind(dim=1)\n\n            hidden_states = hidden_states * (1 + scale1) + shift1\n\n        hidden_states = self.non_linearity(hidden_states)\n\n        hidden_states = self.conv1(hidden_states, causal=causal)\n\n        if self.inject_noise:\n            hidden_states = self._feed_spatial_noise(\n                hidden_states,\n                self.per_channel_scale1.to(device=hidden_states.device, dtype=hidden_states.dtype),\n                generator=generator,\n            )\n\n        hidden_states = self.norm2(hidden_states)\n\n        if self.timestep_conditioning:\n            hidden_states = hidden_states * (1 + scale2) + shift2\n\n        hidden_states = self.non_linearity(hidden_states)\n\n        hidden_states = self.dropout(hidden_states)\n\n        hidden_states = self.conv2(hidden_states, causal=causal)\n\n        if self.inject_noise:\n            hidden_states = self._feed_spatial_noise(\n                hidden_states,\n                self.per_channel_scale2.to(device=hidden_states.device, dtype=hidden_states.dtype),\n                generator=generator,\n            )\n\n        input_tensor = 
self.norm3(input_tensor)\n\n        batch_size = input_tensor.shape[0]\n\n        input_tensor = self.conv_shortcut(input_tensor)\n\n        output_tensor = input_tensor + hidden_states\n\n        return output_tensor\n\n\nclass UNetMidBlock3D(nn.Module):\n    \"\"\"\n    A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks.\n    Args:\n        in_channels (`int`): The number of input channels.\n        dropout (`float`, *optional*, defaults to 0.0): The dropout rate.\n        num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.\n        resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks.\n        resnet_groups (`int`, *optional*, defaults to 32):\n            The number of groups to use in the group normalization layers of the resnet blocks.\n        norm_layer (`str`, *optional*, defaults to `group_norm`):\n            The normalization layer to use. Can be either `group_norm` or `pixel_norm`.\n        inject_noise (`bool`, *optional*, defaults to `False`):\n            Whether to inject noise into the hidden states.\n        timestep_conditioning (`bool`, *optional*, defaults to `False`):\n            Whether to condition the hidden states on the timestep.\n    Returns:\n        `torch.Tensor`: The output of the last residual block, which is a tensor of shape `(batch_size,\n        in_channels, height, width)`.\n    \"\"\"\n\n    def __init__(\n        self,\n        dims: Union[int, Tuple[int, int]],\n        in_channels: int,\n        dropout: float = 0.0,\n        num_layers: int = 1,\n        resnet_eps: float = 1e-6,\n        resnet_groups: int = 32,\n        norm_layer: NormLayerType = NormLayerType.GROUP_NORM,\n        inject_noise: bool = False,\n        timestep_conditioning: bool = False,\n        spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,\n    ):\n        super().__init__()\n        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels 
// 4, 32)\n\n        self.timestep_conditioning = timestep_conditioning\n\n        if timestep_conditioning:\n            self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(embedding_dim=in_channels * 4,\n                                                                           size_emb_dim=0)\n\n        self.res_blocks = nn.ModuleList([\n            ResnetBlock3D(\n                dims=dims,\n                in_channels=in_channels,\n                out_channels=in_channels,\n                eps=resnet_eps,\n                groups=resnet_groups,\n                dropout=dropout,\n                norm_layer=norm_layer,\n                inject_noise=inject_noise,\n                timestep_conditioning=timestep_conditioning,\n                spatial_padding_mode=spatial_padding_mode,\n            ) for _ in range(num_layers)\n        ])\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        causal: bool = True,\n        timestep: Optional[torch.Tensor] = None,\n        generator: Optional[torch.Generator] = None,\n    ) -> torch.Tensor:\n        timestep_embed = None\n        if self.timestep_conditioning:\n            if timestep is None:\n                raise ValueError(\"'timestep' parameter must be provided when 'timestep_conditioning' is True\")\n            batch_size = hidden_states.shape[0]\n            timestep_embed = self.time_embedder(\n                timestep=timestep.flatten(),\n                hidden_dtype=hidden_states.dtype,\n            )\n            timestep_embed = timestep_embed.view(batch_size, timestep_embed.shape[-1], 1, 1, 1)\n\n        for resnet in self.res_blocks:\n            hidden_states = resnet(\n                hidden_states,\n                causal=causal,\n                timestep=timestep_embed,\n                generator=generator,\n            )\n\n        return hidden_states\n\n\nclass SpaceToDepthDownsample(nn.Module):\n\n    def __init__(\n        self,\n        dims: Union[int, 
Tuple[int, int]],\n        in_channels: int,\n        out_channels: int,\n        stride: Tuple[int, int, int],\n        spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,\n    ):\n        super().__init__()\n        self.stride = stride\n        self.group_size = in_channels * math.prod(stride) // out_channels\n        self.conv = make_conv_nd(\n            dims=dims,\n            in_channels=in_channels,\n            out_channels=out_channels // math.prod(stride),\n            kernel_size=3,\n            stride=1,\n            causal=True,\n            spatial_padding_mode=spatial_padding_mode,\n        )\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        causal: bool = True,\n    ) -> torch.Tensor:\n        if self.stride[0] == 2:\n            x = torch.cat([x[:, :, :1, :, :], x], dim=2)  # duplicate first frames for padding\n\n        # skip connection\n        x_in = rearrange(\n            x,\n            \"b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w\",\n            p1=self.stride[0],\n            p2=self.stride[1],\n            p3=self.stride[2],\n        )\n        x_in = rearrange(x_in, \"b (c g) d h w -> b c g d h w\", g=self.group_size)\n        x_in = x_in.mean(dim=2)\n\n        # conv\n        x = self.conv(x, causal=causal)\n        x = rearrange(\n            x,\n            \"b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w\",\n            p1=self.stride[0],\n            p2=self.stride[1],\n            p3=self.stride[2],\n        )\n\n        x = x + x_in\n\n        return x\n\n\nclass DepthToSpaceUpsample(nn.Module):\n\n    def __init__(\n        self,\n        dims: int | Tuple[int, int],\n        in_channels: int,\n        stride: Tuple[int, int, int],\n        residual: bool = False,\n        out_channels_reduction_factor: int = 1,\n        spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,\n    ):\n        super().__init__()\n        self.stride = stride\n        self.out_channels = 
math.prod(stride) * in_channels // out_channels_reduction_factor\n        self.conv = make_conv_nd(\n            dims=dims,\n            in_channels=in_channels,\n            out_channels=self.out_channels,\n            kernel_size=3,\n            stride=1,\n            causal=True,\n            spatial_padding_mode=spatial_padding_mode,\n        )\n        self.residual = residual\n        self.out_channels_reduction_factor = out_channels_reduction_factor\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        causal: bool = True,\n    ) -> torch.Tensor:\n        if self.residual:\n            # Reshape and duplicate the input to match the output shape\n            x_in = rearrange(\n                x,\n                \"b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)\",\n                p1=self.stride[0],\n                p2=self.stride[1],\n                p3=self.stride[2],\n            )\n            num_repeat = math.prod(self.stride) // self.out_channels_reduction_factor\n            x_in = x_in.repeat(1, num_repeat, 1, 1, 1)\n            if self.stride[0] == 2:\n                x_in = x_in[:, :, 1:, :, :]\n        x = self.conv(x, causal=causal)\n        x = rearrange(\n            x,\n            \"b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)\",\n            p1=self.stride[0],\n            p2=self.stride[1],\n            p3=self.stride[2],\n        )\n        if self.stride[0] == 2:\n            x = x[:, :, 1:, :, :]\n        if self.residual:\n            x = x + x_in\n        return x\n\n\ndef compute_trapezoidal_mask_1d(\n    length: int,\n    ramp_left: int,\n    ramp_right: int,\n    left_starts_from_0: bool = False,\n) -> torch.Tensor:\n    \"\"\"\n    Generate a 1D trapezoidal blending mask with linear ramps.\n    Args:\n        length: Output length of the mask.\n        ramp_left: Fade-in length on the left.\n        ramp_right: Fade-out length on the right.\n        left_starts_from_0: Whether the ramp starts from 0 or first 
non-zero value.\n            Useful for temporal tiles where the first tile is causal.\n    Returns:\n        A 1D tensor of shape `(length,)` with values in [0, 1].\n    \"\"\"\n    if length <= 0:\n        raise ValueError(\"Mask length must be positive.\")\n\n    ramp_left = max(0, min(ramp_left, length))\n    ramp_right = max(0, min(ramp_right, length))\n\n    mask = torch.ones(length)\n\n    if ramp_left > 0:\n        interval_length = ramp_left + 1 if left_starts_from_0 else ramp_left + 2\n        fade_in = torch.linspace(0.0, 1.0, interval_length)[:-1]\n        if not left_starts_from_0:\n            fade_in = fade_in[1:]\n        mask[:ramp_left] *= fade_in\n\n    if ramp_right > 0:\n        fade_out = torch.linspace(1.0, 0.0, steps=ramp_right + 2)[1:-1]\n        mask[-ramp_right:] *= fade_out\n\n    return mask.clamp_(0, 1)\n\n\n@dataclass(frozen=True)\nclass SpatialTilingConfig:\n    \"\"\"Configuration for dividing each frame into spatial tiles with optional overlap.\n    Args:\n        tile_size_in_pixels (int): Size of each tile in pixels. Must be at least 64 and divisible by 32.\n        tile_overlap_in_pixels (int, optional): Overlap between tiles in pixels. Must be divisible by 32. 
Defaults to 0.\n    \"\"\"\n\n    tile_size_in_pixels: int\n    tile_overlap_in_pixels: int = 0\n\n    def __post_init__(self) -> None:\n        if self.tile_size_in_pixels < 64:\n            raise ValueError(f\"tile_size_in_pixels must be at least 64, got {self.tile_size_in_pixels}\")\n        if self.tile_size_in_pixels % 32 != 0:\n            raise ValueError(f\"tile_size_in_pixels must be divisible by 32, got {self.tile_size_in_pixels}\")\n        if self.tile_overlap_in_pixels % 32 != 0:\n            raise ValueError(f\"tile_overlap_in_pixels must be divisible by 32, got {self.tile_overlap_in_pixels}\")\n        if self.tile_overlap_in_pixels >= self.tile_size_in_pixels:\n            raise ValueError(\n                f\"Overlap must be less than tile size, got {self.tile_overlap_in_pixels} and {self.tile_size_in_pixels}\"\n            )\n\n\n@dataclass(frozen=True)\nclass TemporalTilingConfig:\n    \"\"\"Configuration for dividing a video into temporal tiles (chunks of frames) with optional overlap.\n    Args:\n        tile_size_in_frames (int): Number of frames in each tile. Must be at least 16 and divisible by 8.\n        tile_overlap_in_frames (int, optional): Number of overlapping frames between consecutive tiles.\n            Must be divisible by 8. 
Defaults to 0.\n    \"\"\"\n\n    tile_size_in_frames: int\n    tile_overlap_in_frames: int = 0\n\n    def __post_init__(self) -> None:\n        if self.tile_size_in_frames < 16:\n            raise ValueError(f\"tile_size_in_frames must be at least 16, got {self.tile_size_in_frames}\")\n        if self.tile_size_in_frames % 8 != 0:\n            raise ValueError(f\"tile_size_in_frames must be divisible by 8, got {self.tile_size_in_frames}\")\n        if self.tile_overlap_in_frames % 8 != 0:\n            raise ValueError(f\"tile_overlap_in_frames must be divisible by 8, got {self.tile_overlap_in_frames}\")\n        if self.tile_overlap_in_frames >= self.tile_size_in_frames:\n            raise ValueError(\n                f\"Overlap must be less than tile size, got {self.tile_overlap_in_frames} and {self.tile_size_in_frames}\"\n            )\n\n\n@dataclass(frozen=True)\nclass TilingConfig:\n    \"\"\"Configuration for splitting video into tiles with optional overlap.\n    Attributes:\n        spatial_config: Configuration for splitting spatial dimensions into tiles.\n        temporal_config: Configuration for splitting temporal dimension into tiles.\n    \"\"\"\n\n    spatial_config: SpatialTilingConfig | None = None\n    temporal_config: TemporalTilingConfig | None = None\n\n    @classmethod\n    def default(cls) -> \"TilingConfig\":\n        return cls(\n            spatial_config=SpatialTilingConfig(tile_size_in_pixels=512, tile_overlap_in_pixels=64),\n            temporal_config=TemporalTilingConfig(tile_size_in_frames=64, tile_overlap_in_frames=24),\n        )\n\n\n@dataclass(frozen=True)\nclass DimensionIntervals:\n    \"\"\"Intervals which a single dimension of the latent space is split into.\n    Each interval is defined by its start, end, left ramp, and right ramp.\n    The start and end are the indices of the first and last element (exclusive) in the interval.\n    Ramps are regions of the interval where the value of the mask tensor is\n    interpolated 
between 0 and 1 for blending with neighboring intervals.\n    The left ramp and right ramp values are the lengths of the left and right ramps.\n    \"\"\"\n\n    starts: List[int]\n    ends: List[int]\n    left_ramps: List[int]\n    right_ramps: List[int]\n\n\n@dataclass(frozen=True)\nclass LatentIntervals:\n    \"\"\"Intervals which the latent tensor of given shape is split into.\n    Each dimension of the latent space is split into intervals based on the length along said dimension.\n    \"\"\"\n\n    original_shape: torch.Size\n    dimension_intervals: Tuple[DimensionIntervals, ...]\n\n\n# Operation to split a single dimension of the tensor into intervals based on the length along the dimension.\nSplitOperation = Callable[[int], DimensionIntervals]\n# Operation to map the intervals in input dimension to slices and masks along a corresponding output dimension.\nMappingOperation = Callable[[DimensionIntervals], tuple[list[slice], list[torch.Tensor | None]]]\n\n\ndef default_split_operation(length: int) -> DimensionIntervals:\n    return DimensionIntervals(starts=[0], ends=[length], left_ramps=[0], right_ramps=[0])\n\n\nDEFAULT_SPLIT_OPERATION: SplitOperation = default_split_operation\n\n\ndef default_mapping_operation(_intervals: DimensionIntervals,) -> tuple[list[slice], list[torch.Tensor | None]]:\n    return [slice(0, None)], [None]\n\n\nDEFAULT_MAPPING_OPERATION: MappingOperation = default_mapping_operation\n\n\nclass Tile(NamedTuple):\n    \"\"\"\n    Represents a single tile.\n    Attributes:\n        in_coords:\n            Tuple of slices specifying where to cut the tile from the INPUT tensor.\n        out_coords:\n            Tuple of slices specifying where this tile's OUTPUT should be placed in the reconstructed OUTPUT tensor.\n        masks_1d:\n            Per-dimension masks in OUTPUT units.\n            These are used to create all-dimensional blending mask.\n    Methods:\n        blend_mask:\n            Create a single N-D mask from the 
per-dimension masks.\n    \"\"\"\n\n    in_coords: Tuple[slice, ...]\n    out_coords: Tuple[slice, ...]\n    masks_1d: Tuple[Tuple[torch.Tensor, ...]]\n\n    @property\n    def blend_mask(self) -> torch.Tensor:\n        num_dims = len(self.out_coords)\n        per_dimension_masks: List[torch.Tensor] = []\n\n        for dim_idx in range(num_dims):\n            mask_1d = self.masks_1d[dim_idx]\n            view_shape = [1] * num_dims\n            if mask_1d is None:\n                # Broadcast mask along this dimension (length 1).\n                one = torch.ones(1)\n\n                view_shape[dim_idx] = 1\n                per_dimension_masks.append(one.view(*view_shape))\n                continue\n\n            # Reshape (L,) -> (1, ..., L, ..., 1) so masks across dimensions broadcast-multiply.\n            view_shape[dim_idx] = mask_1d.shape[0]\n            per_dimension_masks.append(mask_1d.view(*view_shape))\n\n        # Multiply per-dimension masks to form the full N-D mask (separable blending window).\n        combined_mask = per_dimension_masks[0]\n        for mask in per_dimension_masks[1:]:\n            combined_mask = combined_mask * mask\n\n        return combined_mask\n\n\ndef create_tiles_from_intervals_and_mappers(\n    intervals: LatentIntervals,\n    mappers: List[MappingOperation],\n) -> List[Tile]:\n    full_dim_input_slices = []\n    full_dim_output_slices = []\n    full_dim_masks_1d = []\n    for axis_index in range(len(intervals.original_shape)):\n        dimension_intervals = intervals.dimension_intervals[axis_index]\n        starts = dimension_intervals.starts\n        ends = dimension_intervals.ends\n        input_slices = [slice(s, e) for s, e in zip(starts, ends, strict=True)]\n        output_slices, masks_1d = mappers[axis_index](dimension_intervals)\n        full_dim_input_slices.append(input_slices)\n        full_dim_output_slices.append(output_slices)\n        full_dim_masks_1d.append(masks_1d)\n\n    tiles = []\n    tile_in_coords = 
list(itertools.product(*full_dim_input_slices))\n    tile_out_coords = list(itertools.product(*full_dim_output_slices))\n    tile_mask_1ds = list(itertools.product(*full_dim_masks_1d))\n    for in_coord, out_coord, mask_1d in zip(tile_in_coords, tile_out_coords, tile_mask_1ds, strict=True):\n        tiles.append(Tile(\n            in_coords=in_coord,\n            out_coords=out_coord,\n            masks_1d=mask_1d,\n        ))\n    return tiles\n\n\ndef create_tiles(\n    latent_shape: torch.Size,\n    splitters: List[SplitOperation],\n    mappers: List[MappingOperation],\n) -> List[Tile]:\n    if len(splitters) != len(latent_shape):\n        raise ValueError(f\"Number of splitters must be equal to number of dimensions in latent shape, \"\n                         f\"got {len(splitters)} and {len(latent_shape)}\")\n    if len(mappers) != len(latent_shape):\n        raise ValueError(f\"Number of mappers must be equal to number of dimensions in latent shape, \"\n                         f\"got {len(mappers)} and {len(latent_shape)}\")\n    intervals = [splitter(length) for splitter, length in zip(splitters, latent_shape, strict=True)]\n    latent_intervals = LatentIntervals(original_shape=latent_shape, dimension_intervals=tuple(intervals))\n    return create_tiles_from_intervals_and_mappers(latent_intervals, mappers)\n\n\ndef _make_encoder_block(\n    block_name: str,\n    block_config: dict[str, Any],\n    in_channels: int,\n    convolution_dimensions: int,\n    norm_layer: NormLayerType,\n    norm_num_groups: int,\n    spatial_padding_mode: PaddingModeType,\n) -> Tuple[nn.Module, int]:\n    out_channels = in_channels\n\n    if block_name == \"res_x\":\n        block = UNetMidBlock3D(\n            dims=convolution_dimensions,\n            in_channels=in_channels,\n            num_layers=block_config[\"num_layers\"],\n            resnet_eps=1e-6,\n            resnet_groups=norm_num_groups,\n            norm_layer=norm_layer,\n            
spatial_padding_mode=spatial_padding_mode,\n        )\n    elif block_name == \"res_x_y\":\n        out_channels = in_channels * block_config.get(\"multiplier\", 2)\n        block = ResnetBlock3D(\n            dims=convolution_dimensions,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            eps=1e-6,\n            groups=norm_num_groups,\n            norm_layer=norm_layer,\n            spatial_padding_mode=spatial_padding_mode,\n        )\n    elif block_name == \"compress_time\":\n        block = make_conv_nd(\n            dims=convolution_dimensions,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            kernel_size=3,\n            stride=(2, 1, 1),\n            causal=True,\n            spatial_padding_mode=spatial_padding_mode,\n        )\n    elif block_name == \"compress_space\":\n        block = make_conv_nd(\n            dims=convolution_dimensions,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            kernel_size=3,\n            stride=(1, 2, 2),\n            causal=True,\n            spatial_padding_mode=spatial_padding_mode,\n        )\n    elif block_name == \"compress_all\":\n        block = make_conv_nd(\n            dims=convolution_dimensions,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            kernel_size=3,\n            stride=(2, 2, 2),\n            causal=True,\n            spatial_padding_mode=spatial_padding_mode,\n        )\n    elif block_name == \"compress_all_x_y\":\n        out_channels = in_channels * block_config.get(\"multiplier\", 2)\n        block = make_conv_nd(\n            dims=convolution_dimensions,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            kernel_size=3,\n            stride=(2, 2, 2),\n            causal=True,\n            spatial_padding_mode=spatial_padding_mode,\n        )\n    elif block_name == \"compress_all_res\":\n        
out_channels = in_channels * block_config.get(\"multiplier\", 2)\n        block = SpaceToDepthDownsample(\n            dims=convolution_dimensions,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            stride=(2, 2, 2),\n            spatial_padding_mode=spatial_padding_mode,\n        )\n    elif block_name == \"compress_space_res\":\n        out_channels = in_channels * block_config.get(\"multiplier\", 2)\n        block = SpaceToDepthDownsample(\n            dims=convolution_dimensions,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            stride=(1, 2, 2),\n            spatial_padding_mode=spatial_padding_mode,\n        )\n    elif block_name == \"compress_time_res\":\n        out_channels = in_channels * block_config.get(\"multiplier\", 2)\n        block = SpaceToDepthDownsample(\n            dims=convolution_dimensions,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            stride=(2, 1, 1),\n            spatial_padding_mode=spatial_padding_mode,\n        )\n    else:\n        raise ValueError(f\"unknown block: {block_name}\")\n\n    return block, out_channels\n\n\nclass LTX2VideoEncoder(nn.Module):\n    _DEFAULT_NORM_NUM_GROUPS = 32\n    \"\"\"\n    Variational Autoencoder Encoder. Encodes video frames into a latent representation.\n    The encoder compresses the input video through a series of downsampling operations controlled by\n    patch_size and encoder_blocks. The output is a normalized latent tensor with shape (B, 128, F', H', W').\n    Compression Behavior:\n        The total compression is determined by:\n        1. Initial spatial compression via patchify: H -> H/4, W -> W/4 (patch_size=4)\n        2. 
Sequential compression through encoder_blocks based on their stride patterns\n        Compression blocks apply 2x compression in specified dimensions:\n            - \"compress_time\" / \"compress_time_res\": temporal only\n            - \"compress_space\" / \"compress_space_res\": spatial only (H and W)\n            - \"compress_all\" / \"compress_all_res\": all dimensions (F, H, W)\n            - \"res_x\" / \"res_x_y\": no compression\n        Standard LTX Video configuration:\n            - patch_size=4\n            - encoder_blocks: 1x compress_space_res, 1x compress_time_res, 2x compress_all_res\n            - Final dimensions: F' = 1 + (F-1)/8, H' = H/32, W' = W/32\n            - Example: (B, 3, 33, 512, 512) -> (B, 128, 5, 16, 16)\n            - Note: Input must have 1 + 8*k frames (e.g., 1, 9, 17, 25, 33...)\n    Args:\n        convolution_dimensions: The number of dimensions to use in convolutions (2D or 3D).\n        in_channels: The number of input channels. For RGB images, this is 3.\n        out_channels: The number of output channels (latent channels). For latent channels, this is 128.\n        encoder_blocks: The list of blocks to construct the encoder. Each block is a tuple of (block_name, params)\n                        where params is either an int (num_layers) or a dict with configuration.\n        patch_size: The patch size for initial spatial compression. Should be a power of 2.\n        norm_layer: The normalization layer to use. Can be either `group_norm` or `pixel_norm`.\n        latent_log_var: The log variance mode. 
Can be either `per_channel`, `uniform`, `constant` or `none`.\n    \"\"\"\n\n    def __init__(\n        self,\n        convolution_dimensions: int = 3,\n        in_channels: int = 3,\n        out_channels: int = 128,\n        patch_size: int = 4,\n        norm_layer: NormLayerType = NormLayerType.PIXEL_NORM,\n        latent_log_var: LogVarianceType = LogVarianceType.UNIFORM,\n        encoder_spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,\n        encoder_version: str = \"ltx-2\",\n    ):\n        super().__init__()\n        if encoder_version == \"ltx-2\":\n            encoder_blocks = [\n                ['res_x', {'num_layers': 4}],\n                ['compress_space_res', {'multiplier': 2}],\n                ['res_x', {'num_layers': 6}],\n                ['compress_time_res', {'multiplier': 2}],\n                ['res_x', {'num_layers': 6}],\n                ['compress_all_res', {'multiplier': 2}],\n                ['res_x', {'num_layers': 2}],\n                ['compress_all_res', {'multiplier': 2}],\n                ['res_x', {'num_layers': 2}]\n            ]\n        else:\n            # LTX-2.3\n            encoder_blocks = [\n                [\"res_x\", {\"num_layers\": 4}],\n                [\"compress_space_res\", {\"multiplier\": 2}],\n                [\"res_x\", {\"num_layers\": 6}],\n                [\"compress_time_res\", {\"multiplier\": 2}],\n                [\"res_x\", {\"num_layers\": 4}],\n                [\"compress_all_res\", {\"multiplier\": 2}],\n                [\"res_x\", {\"num_layers\": 2}],\n                [\"compress_all_res\", {\"multiplier\": 1}],\n                [\"res_x\", {\"num_layers\": 2}]\n            ]\n        self.patch_size = patch_size\n        self.norm_layer = norm_layer\n        self.latent_channels = out_channels\n        self.latent_log_var = latent_log_var\n        self._norm_num_groups = self._DEFAULT_NORM_NUM_GROUPS\n\n        # Per-channel statistics for normalizing latents\n        
self.per_channel_statistics = PerChannelStatistics(latent_channels=out_channels)\n\n        in_channels = in_channels * patch_size**2\n        feature_channels = out_channels\n\n        self.conv_in = make_conv_nd(\n            dims=convolution_dimensions,\n            in_channels=in_channels,\n            out_channels=feature_channels,\n            kernel_size=3,\n            stride=1,\n            padding=1,\n            causal=True,\n            spatial_padding_mode=encoder_spatial_padding_mode,\n        )\n\n        self.down_blocks = nn.ModuleList([])\n\n        for block_name, block_params in encoder_blocks:\n            # Convert int to dict format for uniform handling\n            block_config = {\"num_layers\": block_params} if isinstance(block_params, int) else block_params\n\n            block, feature_channels = _make_encoder_block(\n                block_name=block_name,\n                block_config=block_config,\n                in_channels=feature_channels,\n                convolution_dimensions=convolution_dimensions,\n                norm_layer=norm_layer,\n                norm_num_groups=self._norm_num_groups,\n                spatial_padding_mode=encoder_spatial_padding_mode,\n            )\n\n            self.down_blocks.append(block)\n\n        # out\n        if norm_layer == NormLayerType.GROUP_NORM:\n            self.conv_norm_out = nn.GroupNorm(num_channels=feature_channels, num_groups=self._norm_num_groups, eps=1e-6)\n        elif norm_layer == NormLayerType.PIXEL_NORM:\n            self.conv_norm_out = PixelNorm()\n\n        self.conv_act = nn.SiLU()\n\n        conv_out_channels = out_channels\n        if latent_log_var == LogVarianceType.PER_CHANNEL:\n            conv_out_channels *= 2\n        elif latent_log_var in {LogVarianceType.UNIFORM, LogVarianceType.CONSTANT}:\n            conv_out_channels += 1\n        elif latent_log_var != LogVarianceType.NONE:\n            raise ValueError(f\"Invalid latent_log_var: {latent_log_var}\")\n\n 
       self.conv_out = make_conv_nd(\n            dims=convolution_dimensions,\n            in_channels=feature_channels,\n            out_channels=conv_out_channels,\n            kernel_size=3,\n            padding=1,\n            causal=True,\n            spatial_padding_mode=encoder_spatial_padding_mode,\n        )\n\n    def forward(self, sample: torch.Tensor) -> torch.Tensor:\n        r\"\"\"\n        Encode video frames into normalized latent representation.\n        Args:\n            sample: Input video (B, C, F, H, W). F must be 1 + 8*k (e.g., 1, 9, 17, 25, 33...).\n        Returns:\n            Normalized latent means (B, 128, F', H', W') where F' = 1+(F-1)/8, H' = H/32, W' = W/32.\n            Example: (B, 3, 33, 512, 512) -> (B, 128, 5, 16, 16).\n        \"\"\"\n        # Validate frame count\n        frames_count = sample.shape[2]\n        if ((frames_count - 1) % 8) != 0:\n            frames_to_crop = (frames_count - 1) % 8\n            sample = sample[:, :, :-frames_to_crop, ...]\n\n        # Initial spatial compression: trade spatial resolution for channel depth\n        # This reduces H,W by patch_size and increases channels, making convolutions more efficient\n        # Example: (B, 3, F, 512, 512) -> (B, 48, F, 128, 128) with patch_size=4\n        sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)\n        sample = self.conv_in(sample)\n\n        for down_block in self.down_blocks:\n            sample = down_block(sample)\n\n        sample = self.conv_norm_out(sample)\n        sample = self.conv_act(sample)\n        sample = self.conv_out(sample)\n\n        if self.latent_log_var == LogVarianceType.UNIFORM:\n            # Uniform Variance: model outputs N means and 1 shared log-variance channel.\n            # We need to expand the single logvar to match the number of means channels\n            # to create a format compatible with PER_CHANNEL (means + logvar, each with N channels).\n            # Sample shape: (B, N+1, ...) 
where N = latent_channels (e.g., 128 means + 1 logvar = 129)\n            # Target shape: (B, 2*N, ...) where first N are means, last N are logvar\n\n            if sample.shape[1] < 2:\n                raise ValueError(f\"Invalid channel count for UNIFORM mode: expected at least 2 channels \"\n                                 f\"(N means + 1 logvar), got {sample.shape[1]}\")\n\n            # Extract means (first N channels) and logvar (last 1 channel)\n            means = sample[:, :-1, ...]  # (B, N, ...)\n            logvar = sample[:, -1:, ...]  # (B, 1, ...)\n\n            # Repeat logvar N times to match means channels\n            # Use expand/repeat pattern that works for both 4D and 5D tensors\n            num_channels = means.shape[1]\n            repeat_shape = [1, num_channels] + [1] * (sample.ndim - 2)\n            repeated_logvar = logvar.repeat(*repeat_shape)  # (B, N, ...)\n\n            # Concatenate to create (B, 2*N, ...) format: [means, repeated_logvar]\n            sample = torch.cat([means, repeated_logvar], dim=1)\n        elif self.latent_log_var == LogVarianceType.CONSTANT:\n            sample = sample[:, :-1, ...]\n            approx_ln_0 = -30  # this is the minimal clamp value in DiagonalGaussianDistribution objects\n            sample = torch.cat(\n                [sample, torch.ones_like(sample, device=sample.device) * approx_ln_0],\n                dim=1,\n            )\n\n        # Split into means and logvar, then normalize means\n        means, _ = torch.chunk(sample, 2, dim=1)\n        return self.per_channel_statistics.normalize(means)\n\n\n    def tiled_encode_video(\n        self,\n        video: torch.Tensor,\n        tile_size: int = 512,\n        tile_overlap: int = 128,\n    ) -> torch.Tensor:\n        \"\"\"Encode video using spatial tiling for memory efficiency.\n        Splits the video into overlapping spatial tiles, encodes each tile separately,\n        and blends the results using linear feathering in the overlap 
regions.\n        Args:\n            video: Input tensor of shape [B, C, F, H, W]\n            tile_size: Tile size in pixels (must be divisible by 32)\n            tile_overlap: Overlap between tiles in pixels (must be divisible by 32)\n        Returns:\n            Encoded latent tensor [B, C_latent, F_latent, H_latent, W_latent]\n        \"\"\"\n        batch, _channels, frames, height, width = video.shape\n        device = video.device\n        dtype = video.dtype\n\n        # Validate tile parameters\n        if tile_size % VAE_SPATIAL_FACTOR != 0:\n            raise ValueError(f\"tile_size must be divisible by {VAE_SPATIAL_FACTOR}, got {tile_size}\")\n        if tile_overlap % VAE_SPATIAL_FACTOR != 0:\n            raise ValueError(f\"tile_overlap must be divisible by {VAE_SPATIAL_FACTOR}, got {tile_overlap}\")\n        if tile_overlap >= tile_size:\n            raise ValueError(f\"tile_overlap ({tile_overlap}) must be less than tile_size ({tile_size})\")\n\n        # If video fits in a single tile, use regular encoding\n        if height <= tile_size and width <= tile_size:\n            return self.forward(video)\n\n        # Calculate output dimensions\n        # VAE compresses: H -> H/32, W -> W/32, F -> 1 + (F-1)/8\n        output_height = height // VAE_SPATIAL_FACTOR\n        output_width = width // VAE_SPATIAL_FACTOR\n        output_frames = 1 + (frames - 1) // VAE_TEMPORAL_FACTOR\n\n        # Latent channels (128 for LTX-2)\n        # Get from a small test encode or assume 128\n        latent_channels = 128\n\n        # Initialize output and weight tensors\n        output = torch.zeros(\n            (batch, latent_channels, output_frames, output_height, output_width),\n            device=device,\n            dtype=dtype,\n        )\n        weights = torch.zeros(\n            (batch, 1, output_frames, output_height, output_width),\n            device=device,\n            dtype=dtype,\n        )\n\n        # Calculate tile positions with overlap\n        
# Step size is tile_size - tile_overlap\n        step_h = tile_size - tile_overlap\n        step_w = tile_size - tile_overlap\n\n        h_positions = list(range(0, max(1, height - tile_overlap), step_h))\n        w_positions = list(range(0, max(1, width - tile_overlap), step_w))\n\n        # Ensure last tile covers the edge\n        if h_positions[-1] + tile_size < height:\n            h_positions.append(height - tile_size)\n        if w_positions[-1] + tile_size < width:\n            w_positions.append(width - tile_size)\n\n        # Remove duplicates and sort\n        h_positions = sorted(set(h_positions))\n        w_positions = sorted(set(w_positions))\n\n        # Overlap in latent space\n        overlap_out_h = tile_overlap // VAE_SPATIAL_FACTOR\n        overlap_out_w = tile_overlap // VAE_SPATIAL_FACTOR\n\n        # Process each tile\n        for h_pos in h_positions:\n            for w_pos in w_positions:\n                # Calculate tile boundaries in input space\n                h_start = max(0, h_pos)\n                w_start = max(0, w_pos)\n                h_end = min(h_start + tile_size, height)\n                w_end = min(w_start + tile_size, width)\n\n                # Ensure tile dimensions are divisible by VAE_SPATIAL_FACTOR\n                tile_h = ((h_end - h_start) // VAE_SPATIAL_FACTOR) * VAE_SPATIAL_FACTOR\n                tile_w = ((w_end - w_start) // VAE_SPATIAL_FACTOR) * VAE_SPATIAL_FACTOR\n\n                if tile_h < VAE_SPATIAL_FACTOR or tile_w < VAE_SPATIAL_FACTOR:\n                    continue\n\n                # Adjust end positions\n                h_end = h_start + tile_h\n                w_end = w_start + tile_w\n\n                # Extract tile\n                tile = video[:, :, :, h_start:h_end, w_start:w_end]\n\n                # Encode tile\n                encoded_tile = self.forward(tile)\n\n                # Get actual encoded dimensions\n                _, _, tile_out_frames, tile_out_height, tile_out_width = 
encoded_tile.shape\n\n                # Calculate output positions\n                out_h_start = h_start // VAE_SPATIAL_FACTOR\n                out_w_start = w_start // VAE_SPATIAL_FACTOR\n                out_h_end = min(out_h_start + tile_out_height, output_height)\n                out_w_end = min(out_w_start + tile_out_width, output_width)\n\n                # Trim encoded tile if necessary\n                actual_tile_h = out_h_end - out_h_start\n                actual_tile_w = out_w_end - out_w_start\n                encoded_tile = encoded_tile[:, :, :, :actual_tile_h, :actual_tile_w]\n\n                # Create blending mask with linear feathering at edges\n                mask = torch.ones(\n                    (1, 1, tile_out_frames, actual_tile_h, actual_tile_w),\n                    device=device,\n                    dtype=dtype,\n                )\n\n                # Apply feathering at edges (linear blend in overlap regions)\n                # Left edge\n                if h_pos > 0 and overlap_out_h > 0 and overlap_out_h < actual_tile_h:\n                    fade_in = torch.linspace(0.0, 1.0, overlap_out_h + 2, device=device, dtype=dtype)[1:-1]\n                    mask[:, :, :, :overlap_out_h, :] *= fade_in.view(1, 1, 1, -1, 1)\n\n                # Right edge (bottom in height dimension)\n                if h_end < height and overlap_out_h > 0 and overlap_out_h < actual_tile_h:\n                    fade_out = torch.linspace(1.0, 0.0, overlap_out_h + 2, device=device, dtype=dtype)[1:-1]\n                    mask[:, :, :, -overlap_out_h:, :] *= fade_out.view(1, 1, 1, -1, 1)\n\n                # Top edge (left in width dimension)\n                if w_pos > 0 and overlap_out_w > 0 and overlap_out_w < actual_tile_w:\n                    fade_in = torch.linspace(0.0, 1.0, overlap_out_w + 2, device=device, dtype=dtype)[1:-1]\n                    mask[:, :, :, :, :overlap_out_w] *= fade_in.view(1, 1, 1, 1, -1)\n\n                # Bottom edge (right in 
width dimension)\n                if w_end < width and overlap_out_w > 0 and overlap_out_w < actual_tile_w:\n                    fade_out = torch.linspace(1.0, 0.0, overlap_out_w + 2, device=device, dtype=dtype)[1:-1]\n                    mask[:, :, :, :, -overlap_out_w:] *= fade_out.view(1, 1, 1, 1, -1)\n\n                # Accumulate weighted results\n                output[:, :, :, out_h_start:out_h_end, out_w_start:out_w_end] += encoded_tile * mask\n                weights[:, :, :, out_h_start:out_h_end, out_w_start:out_w_end] += mask\n\n        # Normalize by weights (avoid division by zero)\n        output = output / (weights + 1e-8)\n\n        return output\n\n    def encode(\n        self,\n        video: torch.Tensor,\n        tiled=False,\n        tile_size_in_pixels: Optional[int] = 512,\n        tile_overlap_in_pixels: Optional[int] = 128,\n        **kwargs,\n    ) -> torch.Tensor:\n        if video.ndim == 4:\n            video = video.unsqueeze(0)  # [C, F, H, W] -> [B, C, F, H, W]\n        # Choose encoding method based on tiling flag\n        if tiled:\n            latents = self.tiled_encode_video(\n                video=video,\n                tile_size=tile_size_in_pixels,\n                tile_overlap=tile_overlap_in_pixels,\n            )\n        else:\n            # Encode video - VAE expects [B, C, F, H, W], returns [B, C, F', H', W']\n            latents = self.forward(video)\n        return latents\n\n\ndef _make_decoder_block(\n    block_name: str,\n    block_config: dict[str, Any],\n    in_channels: int,\n    convolution_dimensions: int,\n    norm_layer: NormLayerType,\n    timestep_conditioning: bool,\n    norm_num_groups: int,\n    spatial_padding_mode: PaddingModeType,\n) -> Tuple[nn.Module, int]:\n    out_channels = in_channels\n    if block_name == \"res_x\":\n        block = UNetMidBlock3D(\n            dims=convolution_dimensions,\n            in_channels=in_channels,\n            num_layers=block_config[\"num_layers\"],\n         
   resnet_eps=1e-6,\n            resnet_groups=norm_num_groups,\n            norm_layer=norm_layer,\n            inject_noise=block_config.get(\"inject_noise\", False),\n            timestep_conditioning=timestep_conditioning,\n            spatial_padding_mode=spatial_padding_mode,\n        )\n    elif block_name == \"attn_res_x\":\n        block = UNetMidBlock3D(\n            dims=convolution_dimensions,\n            in_channels=in_channels,\n            num_layers=block_config[\"num_layers\"],\n            resnet_groups=norm_num_groups,\n            norm_layer=norm_layer,\n            inject_noise=block_config.get(\"inject_noise\", False),\n            timestep_conditioning=timestep_conditioning,\n            attention_head_dim=block_config[\"attention_head_dim\"],\n            spatial_padding_mode=spatial_padding_mode,\n        )\n    elif block_name == \"res_x_y\":\n        out_channels = in_channels // block_config.get(\"multiplier\", 2)\n        block = ResnetBlock3D(\n            dims=convolution_dimensions,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            eps=1e-6,\n            groups=norm_num_groups,\n            norm_layer=norm_layer,\n            inject_noise=block_config.get(\"inject_noise\", False),\n            timestep_conditioning=False,\n            spatial_padding_mode=spatial_padding_mode,\n        )\n    elif block_name == \"compress_time\":\n        out_channels = in_channels // block_config.get(\"multiplier\", 1)\n        block = DepthToSpaceUpsample(\n            dims=convolution_dimensions,\n            in_channels=in_channels,\n            stride=(2, 1, 1),\n            out_channels_reduction_factor=block_config.get(\"multiplier\", 1),\n            spatial_padding_mode=spatial_padding_mode,\n        )\n    elif block_name == \"compress_space\":\n        out_channels = in_channels // block_config.get(\"multiplier\", 1)\n        block = DepthToSpaceUpsample(\n            dims=convolution_dimensions,\n 
           in_channels=in_channels,\n            stride=(1, 2, 2),\n            out_channels_reduction_factor=block_config.get(\"multiplier\", 1),\n            spatial_padding_mode=spatial_padding_mode,\n        )\n    elif block_name == \"compress_all\":\n        out_channels = in_channels // block_config.get(\"multiplier\", 1)\n        block = DepthToSpaceUpsample(\n            dims=convolution_dimensions,\n            in_channels=in_channels,\n            stride=(2, 2, 2),\n            residual=block_config.get(\"residual\", False),\n            out_channels_reduction_factor=block_config.get(\"multiplier\", 1),\n            spatial_padding_mode=spatial_padding_mode,\n        )\n    else:\n        raise ValueError(f\"unknown layer: {block_name}\")\n\n    return block, out_channels\n\n\nclass LTX2VideoDecoder(nn.Module):\n    _DEFAULT_NORM_NUM_GROUPS = 32\n    \"\"\"\n    Variational Autoencoder Decoder. Decodes latent representation into video frames.\n    The decoder upsamples latents through a series of upsampling operations (inverse of encoder).\n    Output dimensions: F = 8x(F'-1) + 1, H = 32xH', W = 32xW' for standard LTX Video configuration.\n    Upsampling blocks expand dimensions by 2x in specified dimensions:\n        - \"compress_time\": temporal only\n        - \"compress_space\": spatial only (H and W)\n        - \"compress_all\": all dimensions (F, H, W)\n        - \"res_x\" / \"res_x_y\" / \"attn_res_x\": no upsampling\n    Causal Mode:\n        causal=False (standard): Symmetric padding, allows future frame dependencies.\n        causal=True: Causal padding, each frame depends only on past/current frames.\n        First frame removed after temporal upsampling in both modes. Output shape unchanged.\n        Example: (B, 128, 5, 16, 16) -> (B, 3, 33, 512, 512) for both modes.\n    Args:\n        convolution_dimensions: The number of dimensions to use in convolutions (2D or 3D).\n        in_channels: The number of input channels (latent channels). 
Default is 128.\n        out_channels: The number of output channels. For RGB images, this is 3.\n        decoder_blocks: The list of blocks to construct the decoder. Each block is a tuple of (block_name, params)\n                        where params is either an int (num_layers) or a dict with configuration.\n        patch_size: Final spatial expansion factor. For standard LTX Video, use 4 for 4x spatial expansion:\n                    H -> Hx4, W -> Wx4. Should be a power of 2.\n        norm_layer: The normalization layer to use. Can be either `group_norm` or `pixel_norm`.\n        causal: Whether to use causal convolutions. For standard LTX Video, use False for symmetric padding.\n                When True, uses causal padding (past/current frames only).\n        timestep_conditioning: Whether to condition the decoder on timestep for denoising.\n    \"\"\"\n\n    def __init__(\n        self,\n        convolution_dimensions: int = 3,\n        in_channels: int = 128,\n        out_channels: int = 3,\n        decoder_blocks: List[Tuple[str, int | dict]] = [],  # noqa: B006\n        patch_size: int = 4,\n        norm_layer: NormLayerType = NormLayerType.PIXEL_NORM,\n        causal: bool = False,\n        timestep_conditioning: bool = False,\n        decoder_spatial_padding_mode: PaddingModeType = PaddingModeType.REFLECT,\n        decoder_version: str = \"ltx-2\",\n        base_channels: int = 128,\n    ):\n        super().__init__()\n\n        # Spatiotemporal downscaling between decoded video space and VAE latents.\n        # According to the LTXV paper, the standard configuration downsamples\n        # video inputs by a factor of 8 in the temporal dimension and 32 in\n        # each spatial dimension (height and width). 
This parameter determines how\n        # many video frames and pixels correspond to a single latent cell.\n        if decoder_version == \"ltx-2\":\n            decoder_blocks = [\n                ['res_x', {'num_layers': 5, 'inject_noise': False}],\n                ['compress_all', {'residual': True, 'multiplier': 2}],\n                ['res_x', {'num_layers': 5, 'inject_noise': False}],\n                ['compress_all', {'residual': True, 'multiplier': 2}],\n                ['res_x', {'num_layers': 5, 'inject_noise': False}],\n                ['compress_all', {'residual': True, 'multiplier': 2}],\n                ['res_x', {'num_layers': 5, 'inject_noise': False}]\n            ]\n        else:\n            # LTX-2.3\n            decoder_blocks = [\n                [\"res_x\", {\"num_layers\": 4}],\n                [\"compress_space\", {\"multiplier\": 2}],\n                [\"res_x\", {\"num_layers\": 6}],\n                [\"compress_time\", {\"multiplier\": 2}],\n                [\"res_x\", {\"num_layers\": 4}],\n                [\"compress_all\", {\"multiplier\": 1}],\n                [\"res_x\", {\"num_layers\": 2}],\n                [\"compress_all\", {\"multiplier\": 2}],\n                [\"res_x\", {\"num_layers\": 2}]\n            ]\n        self.video_downscale_factors = SpatioTemporalScaleFactors(\n            time=8,\n            width=32,\n            height=32,\n        )\n\n        self.patch_size = patch_size\n        out_channels = out_channels * patch_size**2\n        self.causal = causal\n        self.timestep_conditioning = timestep_conditioning\n        self._norm_num_groups = self._DEFAULT_NORM_NUM_GROUPS\n\n        # Per-channel statistics for denormalizing latents\n        self.per_channel_statistics = PerChannelStatistics(latent_channels=in_channels)\n\n        # Noise and timestep parameters for decoder conditioning\n        self.decode_noise_scale = 0.025\n        self.decode_timestep = 0.05\n\n        # LTX VAE decoder architecture 
uses 3 upsampler blocks with multiplier equals to 2.\n        # Hence the total feature_channels is multiplied by 8 (2^3).\n        feature_channels = base_channels * 8\n\n        self.conv_in = make_conv_nd(\n            dims=convolution_dimensions,\n            in_channels=in_channels,\n            out_channels=feature_channels,\n            kernel_size=3,\n            stride=1,\n            padding=1,\n            causal=True,\n            spatial_padding_mode=decoder_spatial_padding_mode,\n        )\n\n        self.up_blocks = nn.ModuleList([])\n\n        for block_name, block_params in list(reversed(decoder_blocks)):\n            # Convert int to dict format for uniform handling\n            block_config = {\"num_layers\": block_params} if isinstance(block_params, int) else block_params\n\n            block, feature_channels = _make_decoder_block(\n                block_name=block_name,\n                block_config=block_config,\n                in_channels=feature_channels,\n                convolution_dimensions=convolution_dimensions,\n                norm_layer=norm_layer,\n                timestep_conditioning=timestep_conditioning,\n                norm_num_groups=self._norm_num_groups,\n                spatial_padding_mode=decoder_spatial_padding_mode,\n            )\n\n            self.up_blocks.append(block)\n\n        if norm_layer == NormLayerType.GROUP_NORM:\n            self.conv_norm_out = nn.GroupNorm(num_channels=feature_channels, num_groups=self._norm_num_groups, eps=1e-6)\n        elif norm_layer == NormLayerType.PIXEL_NORM:\n            self.conv_norm_out = PixelNorm()\n\n        self.conv_act = nn.SiLU()\n        self.conv_out = make_conv_nd(\n            dims=convolution_dimensions,\n            in_channels=feature_channels,\n            out_channels=out_channels,\n            kernel_size=3,\n            padding=1,\n            causal=True,\n            spatial_padding_mode=decoder_spatial_padding_mode,\n        )\n\n        if 
timestep_conditioning:\n            self.timestep_scale_multiplier = nn.Parameter(torch.tensor(1000.0))\n            self.last_time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(embedding_dim=feature_channels * 2,\n                                                                                size_emb_dim=0)\n            self.last_scale_shift_table = nn.Parameter(torch.empty(2, feature_channels))\n\n    def forward(\n        self,\n        sample: torch.Tensor,\n        timestep: torch.Tensor | None = None,\n        generator: torch.Generator | None = None,\n    ) -> torch.Tensor:\n        r\"\"\"\n        Decode latent representation into video frames.\n        Args:\n            sample: Latent tensor (B, 128, F', H', W').\n            timestep: Timestep for conditioning (if timestep_conditioning=True). Uses default 0.05 if None.\n            generator: Random generator for deterministic noise injection (if inject_noise=True in blocks).\n        Returns:\n            Decoded video (B, 3, F, H, W) where F = 8x(F'-1) + 1, H = 32xH', W = 32xW'.\n            Example: (B, 128, 5, 16, 16) -> (B, 3, 33, 512, 512).\n            Note: First frame is removed after temporal upsampling regardless of causal mode.\n            When causal=False, allows future frame dependencies in convolutions but maintains same output shape.\n        \"\"\"\n        batch_size = sample.shape[0]\n\n        # Add noise if timestep conditioning is enabled\n        if self.timestep_conditioning:\n            noise = (torch.randn(\n                sample.size(),\n                generator=generator,\n                dtype=sample.dtype,\n                device=sample.device,\n            ) * self.decode_noise_scale)\n\n            sample = noise + (1.0 - self.decode_noise_scale) * sample\n\n        # Denormalize latents\n        sample = self.per_channel_statistics.un_normalize(sample)\n\n        # Use default decode_timestep if timestep not provided\n        if timestep is None and 
self.timestep_conditioning:\n            timestep = torch.full((batch_size,), self.decode_timestep, device=sample.device, dtype=sample.dtype)\n\n        sample = self.conv_in(sample, causal=self.causal)\n\n        scaled_timestep = None\n        if self.timestep_conditioning:\n            if timestep is None:\n                raise ValueError(\"'timestep' parameter must be provided when 'timestep_conditioning' is True\")\n            scaled_timestep = timestep * self.timestep_scale_multiplier.to(sample)\n\n        for up_block in self.up_blocks:\n            if isinstance(up_block, UNetMidBlock3D):\n                block_kwargs = {\n                    \"causal\": self.causal,\n                    \"timestep\": scaled_timestep if self.timestep_conditioning else None,\n                    \"generator\": generator,\n                }\n                sample = up_block(sample, **block_kwargs)\n            elif isinstance(up_block, ResnetBlock3D):\n                sample = up_block(sample, causal=self.causal, generator=generator)\n            else:\n                sample = up_block(sample, causal=self.causal)\n\n        sample = self.conv_norm_out(sample)\n\n        if self.timestep_conditioning:\n            embedded_timestep = self.last_time_embedder(\n                timestep=scaled_timestep.flatten(),\n                hidden_dtype=sample.dtype,\n            )\n            embedded_timestep = embedded_timestep.view(batch_size, embedded_timestep.shape[-1], 1, 1, 1)\n            ada_values = self.last_scale_shift_table[None, ..., None, None, None].to(\n                device=sample.device, dtype=sample.dtype) + embedded_timestep.reshape(\n                    batch_size,\n                    2,\n                    -1,\n                    embedded_timestep.shape[-3],\n                    embedded_timestep.shape[-2],\n                    embedded_timestep.shape[-1],\n                )\n            shift, scale = ada_values.unbind(dim=1)\n            sample = sample * 
(1 + scale) + shift\n\n        sample = self.conv_act(sample)\n        sample = self.conv_out(sample, causal=self.causal)\n\n        # Final spatial expansion: reverse the initial patchify from encoder\n        # Moves pixels from channels back to spatial dimensions\n        # Example: (B, 48, F, 128, 128) -> (B, 3, F, 512, 512) with patch_size=4\n        sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)\n\n        return sample\n\n    def _prepare_tiles(\n        self,\n        latent: torch.Tensor,\n        tiling_config: TilingConfig | None = None,\n    ) -> List[Tile]:\n        splitters = [DEFAULT_SPLIT_OPERATION] * len(latent.shape)\n        mappers = [DEFAULT_MAPPING_OPERATION] * len(latent.shape)\n        if tiling_config is not None and tiling_config.spatial_config is not None:\n            cfg = tiling_config.spatial_config\n            long_side = max(latent.shape[3], latent.shape[4])\n\n            def enable_on_axis(axis_idx: int, factor: int) -> None:\n                size = cfg.tile_size_in_pixels // factor\n                overlap = cfg.tile_overlap_in_pixels // factor\n                axis_length = latent.shape[axis_idx]\n                lower_threshold = max(2, overlap + 1)\n                tile_size = max(lower_threshold, round(size * axis_length / long_side))\n                splitters[axis_idx] = split_in_spatial(tile_size, overlap)\n                mappers[axis_idx] = to_mapping_operation(map_spatial_slice, factor)\n\n            enable_on_axis(3, self.video_downscale_factors.height)\n            enable_on_axis(4, self.video_downscale_factors.width)\n\n        if tiling_config is not None and tiling_config.temporal_config is not None:\n            cfg = tiling_config.temporal_config\n            tile_size = cfg.tile_size_in_frames // self.video_downscale_factors.time\n            overlap = cfg.tile_overlap_in_frames // self.video_downscale_factors.time\n            splitters[2] = split_in_temporal(tile_size, overlap)\n 
           mappers[2] = to_mapping_operation(map_temporal_slice, self.video_downscale_factors.time)\n\n        return create_tiles(latent.shape, splitters, mappers)\n\n    def tiled_decode(\n        self,\n        latent: torch.Tensor,\n        tiling_config: TilingConfig | None = None,\n        timestep: torch.Tensor | None = None,\n        generator: torch.Generator | None = None,\n    ) -> Iterator[torch.Tensor]:\n        \"\"\"\n        Decode a latent tensor into video frames using tiled processing.\n        Splits the latent tensor into tiles, decodes each tile individually,\n        and yields video chunks as they become available.\n        Args:\n            latent: Input latent tensor (B, C, F', H', W').\n            tiling_config: Tiling configuration for the latent tensor.\n            timestep: Optional timestep for decoder conditioning.\n            generator: Optional random generator for deterministic decoding.\n        Yields:\n            Video chunks (B, C, T, H, W) by temporal slices;\n        \"\"\"\n\n        # Calculate full video shape from latent shape to get spatial dimensions\n        full_video_shape = VideoLatentShape.from_torch_shape(latent.shape).upscale(self.video_downscale_factors)\n        tiles = self._prepare_tiles(latent, tiling_config)\n\n        temporal_groups = self._group_tiles_by_temporal_slice(tiles)\n\n        # State for temporal overlap handling\n        previous_chunk = None\n        previous_weights = None\n        previous_temporal_slice = None\n\n        for temporal_group_tiles in temporal_groups:\n            curr_temporal_slice = temporal_group_tiles[0].out_coords[2]\n\n            # Calculate the shape of the temporal buffer for this group of tiles.\n            # The temporal length depends on whether this is the first tile (starts at 0) or not.\n            # - First tile: (frames - 1) * scale + 1\n            # - Subsequent tiles: frames * scale\n            # This logic is handled by TemporalAxisMapping and 
reflected in out_coords.\n            temporal_tile_buffer_shape = full_video_shape._replace(frames=curr_temporal_slice.stop -\n                                                                   curr_temporal_slice.start,)\n\n            buffer = torch.zeros(\n                temporal_tile_buffer_shape.to_torch_shape(),\n                device=latent.device,\n                dtype=latent.dtype,\n            )\n\n            curr_weights = self._accumulate_temporal_group_into_buffer(\n                group_tiles=temporal_group_tiles,\n                buffer=buffer,\n                latent=latent,\n                timestep=timestep,\n                generator=generator,\n            )\n\n            # Blend with previous temporal chunk if it exists\n            if previous_chunk is not None:\n                # Check if current temporal slice overlaps with previous temporal slice\n                if previous_temporal_slice.stop > curr_temporal_slice.start:\n                    overlap_len = previous_temporal_slice.stop - curr_temporal_slice.start\n                    temporal_overlap_slice = slice(curr_temporal_slice.start - previous_temporal_slice.start, None)\n\n                    # The overlap is already masked before it reaches this step. Each tile is accumulated into buffer\n                    # with its trapezoidal mask, and curr_weights accumulates the same mask. 
In the overlap blend we add\n                    # the masked values (buffer[...]) and the corresponding weights (curr_weights[...]) into the\n                    # previous buffers, then later normalize by weights.\n                    previous_chunk[:, :, temporal_overlap_slice, :, :] += buffer[:, :, slice(0, overlap_len), :, :]\n                    previous_weights[:, :, temporal_overlap_slice, :, :] += curr_weights[:, :,\n                                                                                         slice(0, overlap_len), :, :]\n\n                    buffer[:, :, slice(0, overlap_len), :, :] = previous_chunk[:, :, temporal_overlap_slice, :, :]\n                    curr_weights[:, :, slice(0, overlap_len), :, :] = previous_weights[:, :,\n                                                                                       temporal_overlap_slice, :, :]\n\n                # Yield the non-overlapping part of the previous chunk\n                previous_weights = previous_weights.clamp(min=1e-8)\n                yield_len = curr_temporal_slice.start - previous_temporal_slice.start\n                yield (previous_chunk / previous_weights)[:, :, :yield_len, :, :]\n\n            # Update state for next iteration\n            previous_chunk = buffer\n            previous_weights = curr_weights\n            previous_temporal_slice = curr_temporal_slice\n\n        # Yield any remaining chunk\n        if previous_chunk is not None:\n            previous_weights = previous_weights.clamp(min=1e-8)\n            yield previous_chunk / previous_weights\n\n    def _group_tiles_by_temporal_slice(self, tiles: List[Tile]) -> List[List[Tile]]:\n        \"\"\"Group tiles by their temporal output slice.\"\"\"\n        if not tiles:\n            return []\n\n        groups = []\n        current_slice = tiles[0].out_coords[2]\n        current_group = []\n\n        for tile in tiles:\n            tile_slice = tile.out_coords[2]\n            if tile_slice == current_slice:\n   
             current_group.append(tile)\n            else:\n                groups.append(current_group)\n                current_slice = tile_slice\n                current_group = [tile]\n\n        # Add the final group\n        if current_group:\n            groups.append(current_group)\n\n        return groups\n\n    def _accumulate_temporal_group_into_buffer(\n        self,\n        group_tiles: List[Tile],\n        buffer: torch.Tensor,\n        latent: torch.Tensor,\n        timestep: torch.Tensor | None,\n        generator: torch.Generator | None,\n    ) -> torch.Tensor:\n        \"\"\"\n        Decode and accumulate all tiles of a temporal group into a local buffer.\n        The buffer is local to the group and always starts at time 0; temporal coordinates\n        are rebased by subtracting temporal_slice.start.\n        \"\"\"\n        temporal_slice = group_tiles[0].out_coords[2]\n\n        weights = torch.zeros_like(buffer)\n\n        for tile in group_tiles:\n            decoded_tile = self.forward(latent[tile.in_coords], timestep, generator)\n            mask = tile.blend_mask.to(device=buffer.device, dtype=buffer.dtype)\n            temporal_offset = tile.out_coords[2].start - temporal_slice.start\n            # Use the tile's output coordinate length, not the decoded tile's length,\n            # as the decoder may produce a different number of frames than expected\n            expected_temporal_len = tile.out_coords[2].stop - tile.out_coords[2].start\n            decoded_temporal_len = decoded_tile.shape[2]\n\n            # Ensure we don't exceed the buffer or decoded tile bounds\n            actual_temporal_len = min(expected_temporal_len, decoded_temporal_len, buffer.shape[2] - temporal_offset)\n\n            chunk_coords = (\n                slice(None),  # batch\n                slice(None),  # channels\n                slice(temporal_offset, temporal_offset + actual_temporal_len),\n                tile.out_coords[3],  # height\n               
 tile.out_coords[4],  # width\n            )\n\n            # Slice decoded_tile and mask to match the actual length we're writing\n            decoded_slice = decoded_tile[:, :, :actual_temporal_len, :, :]\n            mask_slice = mask[:, :, :actual_temporal_len, :, :] if mask.shape[2] > 1 else mask\n\n            buffer[chunk_coords] += decoded_slice * mask_slice\n            weights[chunk_coords] += mask_slice\n\n        return weights\n\n    def decode(\n        self,\n        latent: torch.Tensor,\n        tiled=False,\n        tile_size_in_pixels: Optional[int] = 512,\n        tile_overlap_in_pixels: Optional[int] = 128,\n        tile_size_in_frames: Optional[int] = 128,\n        tile_overlap_in_frames: Optional[int] = 24,\n    ) -> torch.Tensor:\n        if tiled:\n            tiling_config = TilingConfig(\n                spatial_config=SpatialTilingConfig(\n                    tile_size_in_pixels=tile_size_in_pixels,\n                    tile_overlap_in_pixels=tile_overlap_in_pixels,\n                ),\n                temporal_config=TemporalTilingConfig(\n                    tile_size_in_frames=tile_size_in_frames,\n                    tile_overlap_in_frames=tile_overlap_in_frames,\n                ),\n            )\n            tiles = self.tiled_decode(latent, tiling_config)\n            return torch.cat(list(tiles), dim=2)\n        else:\n            return self.forward(latent)\n\ndef decode_video(\n    latent: torch.Tensor,\n    video_decoder: LTX2VideoDecoder,\n    tiling_config: TilingConfig | None = None,\n    generator: torch.Generator | None = None,\n) -> Iterator[torch.Tensor]:\n    \"\"\"\n    Decode a video latent tensor with the given decoder.\n    Args:\n        latent: Tensor [c, f, h, w]\n        video_decoder: Decoder module.\n        tiling_config: Optional tiling settings.\n        generator: Optional random generator for deterministic decoding.\n    Yields:\n        Decoded chunk [f, h, w, c], uint8 in [0, 255].\n    \"\"\"\n\n    
def convert_to_uint8(frames: torch.Tensor) -> torch.Tensor:\n        frames = (((frames + 1.0) / 2.0).clamp(0.0, 1.0) * 255.0).to(torch.uint8)\n        frames = rearrange(frames[0], \"c f h w -> f h w c\")\n        return frames\n\n    if tiling_config is not None:\n        for frames in video_decoder.tiled_decode(latent, tiling_config, generator=generator):\n            return convert_to_uint8(frames)\n    else:\n        decoded_video = video_decoder(latent, generator=generator)\n        return convert_to_uint8(decoded_video)\n\n\ndef get_video_chunks_number(num_frames: int, tiling_config: TilingConfig | None = None) -> int:\n    \"\"\"\n    Get the number of video chunks for a given number of frames and tiling configuration.\n    Args:\n        num_frames: Number of frames in the video.\n        tiling_config: Tiling configuration.\n    Returns:\n        Number of video chunks.\n    \"\"\"\n    if not tiling_config or not tiling_config.temporal_config:\n        return 1\n    cfg = tiling_config.temporal_config\n    frame_stride = cfg.tile_size_in_frames - cfg.tile_overlap_in_frames\n    return (num_frames - 1 + frame_stride - 1) // frame_stride\n\n\ndef split_in_spatial(size: int, overlap: int) -> SplitOperation:\n\n    def split(dimension_size: int) -> DimensionIntervals:\n        if dimension_size <= size:\n            return DEFAULT_SPLIT_OPERATION(dimension_size)\n        amount = (dimension_size + size - 2 * overlap - 1) // (size - overlap)\n        starts = [i * (size - overlap) for i in range(amount)]\n        ends = [start + size for start in starts]\n        ends[-1] = dimension_size\n        left_ramps = [0] + [overlap] * (amount - 1)\n        right_ramps = [overlap] * (amount - 1) + [0]\n        return DimensionIntervals(starts=starts, ends=ends, left_ramps=left_ramps, right_ramps=right_ramps)\n\n    return split\n\n\ndef split_in_temporal(size: int, overlap: int) -> SplitOperation:\n    non_causal_split = split_in_spatial(size, overlap)\n\n    def 
split(dimension_size: int) -> DimensionIntervals:\n        if dimension_size <= size:\n            return DEFAULT_SPLIT_OPERATION(dimension_size)\n        intervals = non_causal_split(dimension_size)\n        starts = intervals.starts\n        starts[1:] = [s - 1 for s in starts[1:]]\n        left_ramps = intervals.left_ramps\n        left_ramps[1:] = [r + 1 for r in left_ramps[1:]]\n        return replace(intervals, starts=starts, left_ramps=left_ramps)\n\n    return split\n\n\ndef to_mapping_operation(\n    map_func: Callable[[int, int, int, int, int], Tuple[slice, torch.Tensor]],\n    scale: int,\n) -> MappingOperation:\n\n    def map_op(intervals: DimensionIntervals) -> tuple[list[slice], list[torch.Tensor | None]]:\n        output_slices: list[slice] = []\n        masks_1d: list[torch.Tensor | None] = []\n        number_of_slices = len(intervals.starts)\n        for i in range(number_of_slices):\n            start = intervals.starts[i]\n            end = intervals.ends[i]\n            left_ramp = intervals.left_ramps[i]\n            right_ramp = intervals.right_ramps[i]\n            output_slice, mask_1d = map_func(start, end, left_ramp, right_ramp, scale)\n            output_slices.append(output_slice)\n            masks_1d.append(mask_1d)\n        return output_slices, masks_1d\n\n    return map_op\n\n\ndef map_temporal_slice(begin: int, end: int, left_ramp: int, right_ramp: int, scale: int) -> Tuple[slice, torch.Tensor]:\n    start = begin * scale\n    stop = 1 + (end - 1) * scale\n    left_ramp = 1 + (left_ramp - 1) * scale\n    right_ramp = right_ramp * scale\n\n    return slice(start, stop), compute_trapezoidal_mask_1d(stop - start, left_ramp, right_ramp, True)\n\n\ndef map_spatial_slice(begin: int, end: int, left_ramp: int, right_ramp: int, scale: int) -> Tuple[slice, torch.Tensor]:\n    start = begin * scale\n    stop = end * scale\n    left_ramp = left_ramp * scale\n    right_ramp = right_ramp * scale\n\n    return slice(start, stop), 
compute_trapezoidal_mask_1d(stop - start, left_ramp, right_ramp, False)\n"
  },
  {
    "path": "diffsynth/models/model_loader.py",
    "content": "from ..core.loader import load_model, hash_model_file\nfrom ..core.vram import AutoWrappedModule\nfrom ..configs import MODEL_CONFIGS, VRAM_MANAGEMENT_MODULE_MAPS, VERSION_CHECKER_MAPS\nimport importlib, json, torch\n\n\nclass ModelPool:\n    def __init__(self):\n        self.model = []\n        self.model_name = []\n        self.model_path = []\n        \n    def import_model_class(self, model_class):\n        split = model_class.rfind(\".\")\n        model_resource, model_class = model_class[:split], model_class[split+1:]\n        model_class = importlib.import_module(model_resource).__getattribute__(model_class)\n        return model_class\n    \n    def need_to_enable_vram_management(self, vram_config):\n        return vram_config[\"offload_dtype\"] is not None and vram_config[\"offload_device\"] is not None\n    \n    def fetch_module_map(self, model_class, vram_config):\n        if self.need_to_enable_vram_management(vram_config):\n            if model_class in VRAM_MANAGEMENT_MODULE_MAPS:\n                vram_module_map = VRAM_MANAGEMENT_MODULE_MAPS[model_class] if model_class not in VERSION_CHECKER_MAPS else VERSION_CHECKER_MAPS[model_class]()\n                module_map = {self.import_model_class(source): self.import_model_class(target) for source, target in vram_module_map.items()}\n            else:\n                module_map = {self.import_model_class(model_class): AutoWrappedModule}\n        else:\n            module_map = None\n        return module_map\n    \n    def load_model_file(self, config, path, vram_config, vram_limit=None, state_dict=None):\n        model_class = self.import_model_class(config[\"model_class\"])\n        model_config = config.get(\"extra_kwargs\", {})\n        if \"state_dict_converter\" in config:\n            state_dict_converter = self.import_model_class(config[\"state_dict_converter\"])\n        else:\n            state_dict_converter = None\n        module_map = 
self.fetch_module_map(config[\"model_class\"], vram_config)\n        model = load_model(\n            model_class, path, model_config,\n            vram_config[\"computation_dtype\"], vram_config[\"computation_device\"],\n            state_dict_converter,\n            use_disk_map=True,\n            vram_config=vram_config, module_map=module_map, vram_limit=vram_limit,\n            state_dict=state_dict,\n        )\n        return model\n    \n    def default_vram_config(self):\n        vram_config = {\n            \"offload_dtype\": None,\n            \"offload_device\": None,\n            \"onload_dtype\": torch.bfloat16,\n            \"onload_device\": \"cpu\",\n            \"preparing_dtype\": torch.bfloat16,\n            \"preparing_device\": \"cpu\",\n            \"computation_dtype\": torch.bfloat16,\n            \"computation_device\": \"cpu\",\n        }\n        return vram_config\n    \n    def auto_load_model(self, path, vram_config=None, vram_limit=None, clear_parameters=False, state_dict=None):\n        print(f\"Loading models from: {json.dumps(path, indent=4)}\")\n        if vram_config is None:\n            vram_config = self.default_vram_config()\n        model_hash = hash_model_file(path)\n        loaded = False\n        for config in MODEL_CONFIGS:\n            if config[\"model_hash\"] == model_hash:\n                model = self.load_model_file(config, path, vram_config, vram_limit=vram_limit, state_dict=state_dict)\n                if clear_parameters: self.clear_parameters(model)\n                self.model.append(model)\n                model_name = config[\"model_name\"]\n                self.model_name.append(model_name)\n                self.model_path.append(path)\n                model_info = {\"model_name\": model_name, \"model_class\": config[\"model_class\"], \"extra_kwargs\": config.get(\"extra_kwargs\")}\n                print(f\"Loaded model: {json.dumps(model_info, indent=4)}\")\n                loaded = True\n        if not 
loaded:\n            raise ValueError(f\"Cannot detect the model type. File: {path}. Model hash: {model_hash}\")\n    \n    def fetch_model(self, model_name, index=None):\n        fetched_models = []\n        fetched_model_paths = []\n        for model, model_path, model_name_ in zip(self.model, self.model_path, self.model_name):\n            if model_name == model_name_:\n                fetched_models.append(model)\n                fetched_model_paths.append(model_path)\n        if len(fetched_models) == 0:\n            print(f\"No {model_name} models available. This is not an error.\")\n            model = None\n        elif len(fetched_models) == 1:\n            print(f\"Using {model_name} from {json.dumps(fetched_model_paths[0], indent=4)}.\")\n            model = fetched_models[0]\n        else:\n            if index is None:\n                model = fetched_models[0]\n                print(f\"More than one {model_name} models are loaded: {fetched_model_paths}. Using {model_name} from {json.dumps(fetched_model_paths[0], indent=4)}.\")\n            elif isinstance(index, int):\n                model = fetched_models[:index]\n                print(f\"More than one {model_name} models are loaded: {fetched_model_paths}. Using {model_name} from {json.dumps(fetched_model_paths[:index], indent=4)}.\")\n            else:\n                model = fetched_models\n                print(f\"More than one {model_name} models are loaded: {fetched_model_paths}. Using {model_name} from {json.dumps(fetched_model_paths, indent=4)}.\")\n        return model\n\n    def clear_parameters(self, model: torch.nn.Module):\n        for name, module in model.named_children():\n            self.clear_parameters(module)\n        for name, param in model.named_parameters(recurse=False):\n            setattr(model, name, None)\n"
  },
  {
    "path": "diffsynth/models/mova_audio_dit.py",
    "content": "import torch\nimport torch.nn as nn\nfrom .wan_video_dit import WanModel, precompute_freqs_cis, sinusoidal_embedding_1d\nfrom einops import rearrange\nfrom ..core import gradient_checkpoint_forward\n\ndef precompute_freqs_cis_1d(dim: int, end: int = 16384, theta: float = 10000.0):\n    f_freqs_cis = precompute_freqs_cis(dim, end, theta)\n    return f_freqs_cis.chunk(3, dim=-1)\n\nclass MovaAudioDit(WanModel):\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        head_dim = kwargs.get(\"dim\", 1536) // kwargs.get(\"num_heads\", 12)\n        self.freqs = precompute_freqs_cis_1d(head_dim)\n        self.patch_embedding = nn.Conv1d(\n            kwargs.get(\"in_dim\", 128), kwargs.get(\"dim\", 1536), kernel_size=[1], stride=[1]\n        )\n\n    def precompute_freqs_cis(self, dim: int, end: int = 16384, theta: float = 10000.0):\n        self.f_freqs_cis = precompute_freqs_cis_1d(dim, end, theta)\n\n    def forward(self,\n                x: torch.Tensor,\n                timestep: torch.Tensor,\n                context: torch.Tensor,\n                use_gradient_checkpointing: bool = False,\n                use_gradient_checkpointing_offload: bool = False,\n                **kwargs,\n                ):\n        t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep))\n        t_mod = self.time_projection(t).unflatten(1, (6, self.dim))\n        context = self.text_embedding(context)\n        x, (f, ) = self.patchify(x)\n        freqs = torch.cat([\n            self.freqs[0][:f].view(f, -1).expand(f, -1),\n            self.freqs[1][:f].view(f, -1).expand(f, -1),\n            self.freqs[2][:f].view(f, -1).expand(f, -1),\n        ], dim=-1).reshape(f, 1, -1).to(x.device)\n\n        for block in self.blocks:\n            x = gradient_checkpoint_forward(\n                block,\n                use_gradient_checkpointing,\n                use_gradient_checkpointing_offload,\n                x, context, 
t_mod, freqs,\n            )\n        x = self.head(x, t)\n        x = self.unpatchify(x, (f, ))\n        return x\n\n    def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor):\n        return rearrange(\n            x, 'b f (p c) -> b c (f p)',\n            f=grid_size[0],\n            p=self.patch_size[0]\n        )\n"
  },
  {
    "path": "diffsynth/models/mova_audio_vae.py",
    "content": "import math\nfrom typing import List, Union\nimport numpy as np\nimport torch\nfrom torch import nn\nfrom torch.nn.utils import weight_norm\nimport torch.nn.functional as F\nfrom einops import rearrange\n\ndef WNConv1d(*args, **kwargs):\n    return weight_norm(nn.Conv1d(*args, **kwargs))\n\n\ndef WNConvTranspose1d(*args, **kwargs):\n    return weight_norm(nn.ConvTranspose1d(*args, **kwargs))\n\n\n# Scripting this brings model speed up 1.4x\n@torch.jit.script\ndef snake(x, alpha):\n    shape = x.shape\n    x = x.reshape(shape[0], shape[1], -1)\n    x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)\n    x = x.reshape(shape)\n    return x\n\n\nclass Snake1d(nn.Module):\n    def __init__(self, channels):\n        super().__init__()\n        self.alpha = nn.Parameter(torch.ones(1, channels, 1))\n\n    def forward(self, x):\n        return snake(x, self.alpha)\n\n\nclass VectorQuantize(nn.Module):\n    \"\"\"\n    Implementation of VQ similar to Karpathy's repo:\n    https://github.com/karpathy/deep-vector-quantization\n    Additionally uses following tricks from Improved VQGAN\n    (https://arxiv.org/pdf/2110.04627.pdf):\n        1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space\n            for improved codebook usage\n        2. 
l2-normalized codes: Converts euclidean distance to cosine similarity which\n            improves training stability\n    \"\"\"\n\n    def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int):\n        super().__init__()\n        self.codebook_size = codebook_size\n        self.codebook_dim = codebook_dim\n\n        self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)\n        self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)\n        self.codebook = nn.Embedding(codebook_size, codebook_dim)\n\n    def forward(self, z):\n        \"\"\"Quantized the input tensor using a fixed codebook and returns\n        the corresponding codebook vectors\n\n        Parameters\n        ----------\n        z : Tensor[B x D x T]\n\n        Returns\n        -------\n        Tensor[B x D x T]\n            Quantized continuous representation of input\n        Tensor[1]\n            Commitment loss to train encoder to predict vectors closer to codebook\n            entries\n        Tensor[1]\n            Codebook loss to update the codebook\n        Tensor[B x T]\n            Codebook indices (quantized discrete representation of input)\n        Tensor[B x D x T]\n            Projected latents (continuous representation of input before quantization)\n        \"\"\"\n\n        # Factorized codes (ViT-VQGAN) Project input into low-dimensional space\n        z_e = self.in_proj(z)  # z_e : (B x D x T)\n        z_q, indices = self.decode_latents(z_e)\n\n        commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction=\"none\").mean([1, 2])\n        codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction=\"none\").mean([1, 2])\n\n        z_q = (\n            z_e + (z_q - z_e).detach()\n        )  # noop in forward pass, straight-through gradient estimator in backward pass\n\n        z_q = self.out_proj(z_q)\n\n        return z_q, commitment_loss, codebook_loss, indices, z_e\n\n    def embed_code(self, embed_id):\n        return F.embedding(embed_id, 
self.codebook.weight)\n\n    def decode_code(self, embed_id):\n        return self.embed_code(embed_id).transpose(1, 2)\n\n    def decode_latents(self, latents):\n        encodings = rearrange(latents, \"b d t -> (b t) d\")\n        codebook = self.codebook.weight  # codebook: (N x D)\n\n        # L2 normalize encodings and codebook (ViT-VQGAN)\n        encodings = F.normalize(encodings)\n        codebook = F.normalize(codebook)\n\n        # Compute euclidean distance with codebook\n        dist = (\n            encodings.pow(2).sum(1, keepdim=True)\n            - 2 * encodings @ codebook.t()\n            + codebook.pow(2).sum(1, keepdim=True).t()\n        )\n        indices = rearrange((-dist).max(1)[1], \"(b t) -> b t\", b=latents.size(0))\n        z_q = self.decode_code(indices)\n        return z_q, indices\n\n\nclass ResidualVectorQuantize(nn.Module):\n    \"\"\"\n    Introduced in SoundStream: An end2end neural audio codec\n    https://arxiv.org/abs/2107.03312\n    \"\"\"\n\n    def __init__(\n        self,\n        input_dim: int = 512,\n        n_codebooks: int = 9,\n        codebook_size: int = 1024,\n        codebook_dim: Union[int, list] = 8,\n        quantizer_dropout: float = 0.0,\n    ):\n        super().__init__()\n        if isinstance(codebook_dim, int):\n            codebook_dim = [codebook_dim for _ in range(n_codebooks)]\n\n        self.n_codebooks = n_codebooks\n        self.codebook_dim = codebook_dim\n        self.codebook_size = codebook_size\n\n        self.quantizers = nn.ModuleList(\n            [\n                VectorQuantize(input_dim, codebook_size, codebook_dim[i])\n                for i in range(n_codebooks)\n            ]\n        )\n        self.quantizer_dropout = quantizer_dropout\n\n    def forward(self, z, n_quantizers: int = None):\n        \"\"\"Quantized the input tensor using a fixed set of `n` codebooks and returns\n        the corresponding codebook vectors\n        Parameters\n        ----------\n        z : Tensor[B x 
D x T]\n        n_quantizers : int, optional\n            No. of quantizers to use\n            (n_quantizers < self.n_codebooks ex: for quantizer dropout)\n            Note: if `self.quantizer_dropout` is True, this argument is ignored\n                when in training mode, and a random number of quantizers is used.\n        Returns\n        -------\n        dict\n            A dictionary with the following keys:\n\n            \"z\" : Tensor[B x D x T]\n                Quantized continuous representation of input\n            \"codes\" : Tensor[B x N x T]\n                Codebook indices for each codebook\n                (quantized discrete representation of input)\n            \"latents\" : Tensor[B x N*D x T]\n                Projected latents (continuous representation of input before quantization)\n            \"vq/commitment_loss\" : Tensor[1]\n                Commitment loss to train encoder to predict vectors closer to codebook\n                entries\n            \"vq/codebook_loss\" : Tensor[1]\n                Codebook loss to update the codebook\n        \"\"\"\n        z_q = 0\n        residual = z\n        commitment_loss = 0\n        codebook_loss = 0\n\n        codebook_indices = []\n        latents = []\n\n        if n_quantizers is None:\n            n_quantizers = self.n_codebooks\n        if self.training:\n            n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1\n            dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))\n            n_dropout = int(z.shape[0] * self.quantizer_dropout)\n            n_quantizers[:n_dropout] = dropout[:n_dropout]\n            n_quantizers = n_quantizers.to(z.device)\n\n        for i, quantizer in enumerate(self.quantizers):\n            if self.training is False and i >= n_quantizers:\n                break\n\n            z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(\n                residual\n            )\n\n            # Create mask to apply 
quantizer dropout\n            mask = (\n                torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers\n            )\n            z_q = z_q + z_q_i * mask[:, None, None]\n            residual = residual - z_q_i\n\n            # Sum losses\n            commitment_loss += (commitment_loss_i * mask).mean()\n            codebook_loss += (codebook_loss_i * mask).mean()\n\n            codebook_indices.append(indices_i)\n            latents.append(z_e_i)\n\n        codes = torch.stack(codebook_indices, dim=1)\n        latents = torch.cat(latents, dim=1)\n\n        return z_q, codes, latents, commitment_loss, codebook_loss\n\n    def from_codes(self, codes: torch.Tensor):\n        \"\"\"Given the quantized codes, reconstruct the continuous representation\n        Parameters\n        ----------\n        codes : Tensor[B x N x T]\n            Quantized discrete representation of input\n        Returns\n        -------\n        Tensor[B x D x T]\n            Quantized continuous representation of input\n        \"\"\"\n        z_q = 0.0\n        z_p = []\n        n_codebooks = codes.shape[1]\n        for i in range(n_codebooks):\n            z_p_i = self.quantizers[i].decode_code(codes[:, i, :])\n            z_p.append(z_p_i)\n\n            z_q_i = self.quantizers[i].out_proj(z_p_i)\n            z_q = z_q + z_q_i\n        return z_q, torch.cat(z_p, dim=1), codes\n\n    def from_latents(self, latents: torch.Tensor):\n        \"\"\"Given the unquantized latents, reconstruct the\n        continuous representation after quantization.\n\n        Parameters\n        ----------\n        latents : Tensor[B x N x T]\n            Continuous representation of input after projection\n\n        Returns\n        -------\n        Tensor[B x D x T]\n            Quantized representation of full-projected space\n        Tensor[B x D x T]\n            Quantized representation of latent space\n        \"\"\"\n        z_q = 0\n        z_p = []\n        codes = []\n      
  dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])\n\n        n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[\n            0\n        ]\n        for i in range(n_codebooks):\n            j, k = dims[i], dims[i + 1]\n            z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])\n            z_p.append(z_p_i)\n            codes.append(codes_i)\n\n            z_q_i = self.quantizers[i].out_proj(z_p_i)\n            z_q = z_q + z_q_i\n\n        return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)\n\n\nclass AbstractDistribution:\n    def sample(self):\n        raise NotImplementedError()\n\n    def mode(self):\n        raise NotImplementedError()\n\n\nclass DiracDistribution(AbstractDistribution):\n    def __init__(self, value):\n        self.value = value\n\n    def sample(self):\n        return self.value\n\n    def mode(self):\n        return self.value\n\n\nclass DiagonalGaussianDistribution(object):\n    def __init__(self, parameters, deterministic=False):\n        self.parameters = parameters\n        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)\n        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)\n        self.deterministic = deterministic\n        self.std = torch.exp(0.5 * self.logvar)\n        self.var = torch.exp(self.logvar)\n        if self.deterministic:\n            self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)\n\n    def sample(self):\n        x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)\n        return x\n\n    def kl(self, other=None):\n        if self.deterministic:\n            return torch.Tensor([0.0])\n        else:\n            if other is None:\n                return 0.5 * torch.mean(\n                    torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,\n                    dim=[1, 2],\n                )\n            else:\n                return 0.5 * 
torch.mean(\n                    torch.pow(self.mean - other.mean, 2) / other.var\n                    + self.var / other.var\n                    - 1.0\n                    - self.logvar\n                    + other.logvar,\n                    dim=[1, 2],\n                )\n\n    def nll(self, sample, dims=[1, 2]):\n        if self.deterministic:\n            return torch.Tensor([0.0])\n        logtwopi = np.log(2.0 * np.pi)\n        return 0.5 * torch.sum(\n            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,\n            dim=dims,\n        )\n\n    def mode(self):\n        return self.mean\n\n\ndef normal_kl(mean1, logvar1, mean2, logvar2):\n    \"\"\"\n    source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12\n    Compute the KL divergence between two gaussians.\n    Shapes are automatically broadcasted, so batches can be compared to\n    scalars, among other use cases.\n    \"\"\"\n    tensor = None\n    for obj in (mean1, logvar1, mean2, logvar2):\n        if isinstance(obj, torch.Tensor):\n            tensor = obj\n            break\n    assert tensor is not None, \"at least one argument must be a Tensor\"\n\n    # Force variances to be Tensors. 
Broadcasting helps convert scalars to\n    # Tensors, but it does not work for torch.exp().\n    logvar1, logvar2 = [x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) for x in (logvar1, logvar2)]\n\n    return 0.5 * (\n        -1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)\n    )\n\n\ndef init_weights(m):\n    if isinstance(m, nn.Conv1d):\n        nn.init.trunc_normal_(m.weight, std=0.02)\n        nn.init.constant_(m.bias, 0)\n\n\nclass ResidualUnit(nn.Module):\n    def __init__(self, dim: int = 16, dilation: int = 1):\n        super().__init__()\n        pad = ((7 - 1) * dilation) // 2\n        self.block = nn.Sequential(\n            Snake1d(dim),\n            WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),\n            Snake1d(dim),\n            WNConv1d(dim, dim, kernel_size=1),\n        )\n\n    def forward(self, x):\n        y = self.block(x)\n        pad = (x.shape[-1] - y.shape[-1]) // 2\n        if pad > 0:\n            x = x[..., pad:-pad]\n        return x + y\n\n\nclass EncoderBlock(nn.Module):\n    def __init__(self, dim: int = 16, stride: int = 1):\n        super().__init__()\n        self.block = nn.Sequential(\n            ResidualUnit(dim // 2, dilation=1),\n            ResidualUnit(dim // 2, dilation=3),\n            ResidualUnit(dim // 2, dilation=9),\n            Snake1d(dim // 2),\n            WNConv1d(\n                dim // 2,\n                dim,\n                kernel_size=2 * stride,\n                stride=stride,\n                padding=math.ceil(stride / 2),\n            ),\n        )\n\n    def forward(self, x):\n        return self.block(x)\n\n\nclass Encoder(nn.Module):\n    def __init__(\n        self,\n        d_model: int = 64,\n        strides: list = [2, 4, 8, 8],\n        d_latent: int = 64,\n    ):\n        super().__init__()\n        # Create first convolution\n        self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)]\n\n 
       # Create EncoderBlocks that double channels as they downsample by `stride`\n        for stride in strides:\n            d_model *= 2\n            self.block += [EncoderBlock(d_model, stride=stride)]\n\n        # Create last convolution\n        self.block += [\n            Snake1d(d_model),\n            WNConv1d(d_model, d_latent, kernel_size=3, padding=1),\n        ]\n\n        # Wrap black into nn.Sequential\n        self.block = nn.Sequential(*self.block)\n        self.enc_dim = d_model\n\n    def forward(self, x):\n        return self.block(x)\n\n\nclass DecoderBlock(nn.Module):\n    def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1):\n        super().__init__()\n        self.block = nn.Sequential(\n            Snake1d(input_dim),\n            WNConvTranspose1d(\n                input_dim,\n                output_dim,\n                kernel_size=2 * stride,\n                stride=stride,\n                padding=math.ceil(stride / 2),\n                output_padding=stride % 2,\n            ),\n            ResidualUnit(output_dim, dilation=1),\n            ResidualUnit(output_dim, dilation=3),\n            ResidualUnit(output_dim, dilation=9),\n        )\n\n    def forward(self, x):\n        return self.block(x)\n\n\nclass Decoder(nn.Module):\n    def __init__(\n        self,\n        input_channel,\n        channels,\n        rates,\n        d_out: int = 1,\n    ):\n        super().__init__()\n\n        # Add first conv layer\n        layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)]\n\n        # Add upsampling + MRF blocks\n        for i, stride in enumerate(rates):\n            input_dim = channels // 2**i\n            output_dim = channels // 2 ** (i + 1)\n            layers += [DecoderBlock(input_dim, output_dim, stride)]\n\n        # Add final conv layer\n        layers += [\n            Snake1d(output_dim),\n            WNConv1d(output_dim, d_out, kernel_size=7, padding=3),\n            
nn.Tanh(),\n        ]\n\n        self.model = nn.Sequential(*layers)\n\n    def forward(self, x):\n        return self.model(x)\n\n\nclass DacVAE(nn.Module):\n\n    def __init__(\n        self,\n        encoder_dim: int = 128,\n        encoder_rates: List[int] = [2, 3, 4, 5, 8],\n        latent_dim: int = 128,\n        decoder_dim: int = 2048,\n        decoder_rates: List[int] = [8, 5, 4, 3, 2],\n        n_codebooks: int = 9,\n        codebook_size: int = 1024,\n        codebook_dim: Union[int, list] = 8,\n        quantizer_dropout: bool = False,\n        sample_rate: int = 48000,\n        continuous: bool = True,\n        use_weight_norm: bool = False,\n    ):\n        super().__init__()\n\n        self.encoder_dim = encoder_dim\n        self.encoder_rates = encoder_rates\n        self.decoder_dim = decoder_dim\n        self.decoder_rates = decoder_rates\n        self.sample_rate = sample_rate\n        self.continuous = continuous\n        self.use_weight_norm = use_weight_norm\n\n        if latent_dim is None:\n            latent_dim = encoder_dim * (2 ** len(encoder_rates))\n\n        self.latent_dim = latent_dim\n\n        self.hop_length = np.prod(encoder_rates)\n        self.encoder = Encoder(encoder_dim, encoder_rates, latent_dim)\n\n        if not continuous:\n            self.n_codebooks = n_codebooks\n            self.codebook_size = codebook_size\n            self.codebook_dim = codebook_dim\n            self.quantizer = ResidualVectorQuantize(\n                input_dim=latent_dim,\n                n_codebooks=n_codebooks,\n                codebook_size=codebook_size,\n                codebook_dim=codebook_dim,\n                quantizer_dropout=quantizer_dropout,\n            )\n        else:\n            self.quant_conv = torch.nn.Conv1d(latent_dim, 2 * latent_dim, 1)\n            self.post_quant_conv = torch.nn.Conv1d(latent_dim, latent_dim, 1)\n\n        self.decoder = Decoder(\n            latent_dim,\n            decoder_dim,\n            
decoder_rates,\n        )\n        self.sample_rate = sample_rate\n        self.apply(init_weights)\n\n        self.delay = self.get_delay()\n\n        if not self.use_weight_norm:\n            self.remove_weight_norm()\n\n    def get_delay(self):\n        # Any number works here, delay is invariant to input length\n        l_out = self.get_output_length(0)\n        L = l_out\n\n        layers = []\n        for layer in self.modules():\n            if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):\n                layers.append(layer)\n\n        for layer in reversed(layers):\n            d = layer.dilation[0]\n            k = layer.kernel_size[0]\n            s = layer.stride[0]\n\n            if isinstance(layer, nn.ConvTranspose1d):\n                L = ((L - d * (k - 1) - 1) / s) + 1\n            elif isinstance(layer, nn.Conv1d):\n                L = (L - 1) * s + d * (k - 1) + 1\n\n            L = math.ceil(L)\n\n        l_in = L\n\n        return (l_in - l_out) // 2\n\n    def get_output_length(self, input_length):\n        L = input_length\n        # Calculate output length\n        for layer in self.modules():\n            if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):\n                d = layer.dilation[0]\n                k = layer.kernel_size[0]\n                s = layer.stride[0]\n\n                if isinstance(layer, nn.Conv1d):\n                    L = ((L - d * (k - 1) - 1) / s) + 1\n                elif isinstance(layer, nn.ConvTranspose1d):\n                    L = (L - 1) * s + d * (k - 1) + 1\n\n                L = math.floor(L)\n        return L\n\n    @property\n    def dtype(self):\n        \"\"\"Get the dtype of the model parameters.\"\"\"\n        # Return the dtype of the first parameter found\n        for param in self.parameters():\n            return param.dtype\n        return torch.float32  # fallback\n\n    @property\n    def device(self):\n        \"\"\"Get the device of the model parameters.\"\"\"\n        # Return 
the device of the first parameter found\n        for param in self.parameters():\n            return param.device\n        return torch.device('cpu')  # fallback\n\n    def preprocess(self, audio_data, sample_rate):\n        if sample_rate is None:\n            sample_rate = self.sample_rate\n        assert sample_rate == self.sample_rate\n\n        length = audio_data.shape[-1]\n        right_pad = math.ceil(length / self.hop_length) * self.hop_length - length\n        audio_data = nn.functional.pad(audio_data, (0, right_pad))\n\n        return audio_data\n\n    def encode(\n        self,\n        audio_data: torch.Tensor,\n        n_quantizers: int = None,\n    ):\n        \"\"\"Encode given audio data and return quantized latent codes\n\n        Parameters\n        ----------\n        audio_data : Tensor[B x 1 x T]\n            Audio data to encode\n        n_quantizers : int, optional\n            Number of quantizers to use, by default None\n            If None, all quantizers are used.\n\n        Returns\n        -------\n        dict\n            A dictionary with the following keys:\n            \"z\" : Tensor[B x D x T]\n                Quantized continuous representation of input\n            \"codes\" : Tensor[B x N x T]\n                Codebook indices for each codebook\n                (quantized discrete representation of input)\n            \"latents\" : Tensor[B x N*D x T]\n                Projected latents (continuous representation of input before quantization)\n            \"vq/commitment_loss\" : Tensor[1]\n                Commitment loss to train encoder to predict vectors closer to codebook\n                entries\n            \"vq/codebook_loss\" : Tensor[1]\n                Codebook loss to update the codebook\n            \"length\" : int\n                Number of samples in input audio\n        \"\"\"\n        z = self.encoder(audio_data)  # [B x D x T]\n        if not self.continuous:\n            z, codes, latents, commitment_loss, 
codebook_loss = self.quantizer(z, n_quantizers)\n        else:\n            z = self.quant_conv(z)  # [B x 2D x T]\n            z = DiagonalGaussianDistribution(z)\n            codes, latents, commitment_loss, codebook_loss = None, None, 0, 0\n\n        return z, codes, latents, commitment_loss, codebook_loss\n\n    def decode(self, z: torch.Tensor):\n        \"\"\"Decode given latent codes and return audio data\n\n        Parameters\n        ----------\n        z : Tensor[B x D x T]\n            Quantized continuous representation of input\n        length : int, optional\n            Number of samples in output audio, by default None\n\n        Returns\n        -------\n        dict\n            A dictionary with the following keys:\n            \"audio\" : Tensor[B x 1 x length]\n                Decoded audio data.\n        \"\"\"\n        if not self.continuous:\n            audio = self.decoder(z)\n        else:\n            z = self.post_quant_conv(z)\n            audio = self.decoder(z)\n\n        return audio\n\n    def forward(\n        self,\n        audio_data: torch.Tensor,\n        sample_rate: int = None,\n        n_quantizers: int = None,\n    ):\n        \"\"\"Model forward pass\n\n        Parameters\n        ----------\n        audio_data : Tensor[B x 1 x T]\n            Audio data to encode\n        sample_rate : int, optional\n            Sample rate of audio data in Hz, by default None\n            If None, defaults to `self.sample_rate`\n        n_quantizers : int, optional\n            Number of quantizers to use, by default None.\n            If None, all quantizers are used.\n\n        Returns\n        -------\n        dict\n            A dictionary with the following keys:\n            \"z\" : Tensor[B x D x T]\n                Quantized continuous representation of input\n            \"codes\" : Tensor[B x N x T]\n                Codebook indices for each codebook\n                (quantized discrete representation of input)\n            
\"latents\" : Tensor[B x N*D x T]\n                Projected latents (continuous representation of input before quantization)\n            \"vq/commitment_loss\" : Tensor[1]\n                Commitment loss to train encoder to predict vectors closer to codebook\n                entries\n            \"vq/codebook_loss\" : Tensor[1]\n                Codebook loss to update the codebook\n            \"length\" : int\n                Number of samples in input audio\n            \"audio\" : Tensor[B x 1 x length]\n                Decoded audio data.\n        \"\"\"\n        length = audio_data.shape[-1]\n        audio_data = self.preprocess(audio_data, sample_rate)\n        if not self.continuous:\n            z, codes, latents, commitment_loss, codebook_loss = self.encode(audio_data, n_quantizers)\n\n            x = self.decode(z)\n            return {\n                \"audio\": x[..., :length],\n                \"z\": z,\n                \"codes\": codes,\n                \"latents\": latents,\n                \"vq/commitment_loss\": commitment_loss,\n                \"vq/codebook_loss\": codebook_loss,\n            }\n        else:\n            posterior, _, _, _, _ = self.encode(audio_data, n_quantizers)\n            z = posterior.sample()\n            x = self.decode(z)\n\n            kl_loss = posterior.kl()\n            kl_loss = kl_loss.mean()\n\n            return {\n                \"audio\": x[..., :length],\n                \"z\": z,\n                \"kl_loss\": kl_loss,\n            }\n\n    def remove_weight_norm(self):\n        \"\"\"\n        Remove weight_norm from all modules in the model.\n        This fuses the weight_g and weight_v parameters into a single weight parameter.\n        Should be called before inference for better performance.\n        Returns:\n            self: The model with weight_norm removed\n        \"\"\"\n        from torch.nn.utils import remove_weight_norm\n        num_removed = 0\n        for name, module in 
list(self.named_modules()):\n            if hasattr(module, \"_forward_pre_hooks\"):\n                for hook_id, hook in list(module._forward_pre_hooks.items()):\n                    if \"WeightNorm\" in str(type(hook)):\n                        try:\n                            remove_weight_norm(module)\n                            num_removed += 1\n                            # print(f\"Removed weight_norm from: {name}\")\n                        except ValueError as e:\n                            print(f\"Failed to remove weight_norm from {name}: {e}\")\n        if num_removed > 0:\n            # print(f\"Successfully removed weight_norm from {num_removed} modules\")\n            self.use_weight_norm = False\n        else:\n            print(\"No weight_norm found in the model\")\n        return self\n"
  },
  {
    "path": "diffsynth/models/mova_dual_tower_bridge.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom typing import Dict, List, Tuple, Optional\nfrom einops import rearrange\nfrom .wan_video_dit import AttentionModule, RMSNorm\nfrom ..core import gradient_checkpoint_forward\n\nclass RotaryEmbedding(nn.Module):\n    inv_freq: torch.Tensor  # fix linting for `register_buffer`\n\n    def __init__(self, base: float, dim: int, device=None):\n        super().__init__()\n        self.base = base\n        self.dim = dim\n        self.attention_scaling = 1.0\n\n        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim))\n        self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n        self.original_inv_freq = self.inv_freq\n\n    @torch.no_grad()\n    def forward(self, x, position_ids):\n        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)\n        position_ids_expanded = position_ids[:, None, :].float()\n\n        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != \"mps\" else \"cpu\"\n        with torch.autocast(device_type=device_type, enabled=False):  # Force float32\n            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)\n            emb = torch.cat((freqs, freqs), dim=-1)\n            cos = emb.cos() * self.attention_scaling\n            sin = emb.sin() * self.attention_scaling\n\n        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)\n\n\ndef rotate_half(x):\n    \"\"\"Rotates half the hidden dims of the input.\"\"\"\n    x1 = x[..., : x.shape[-1] // 2]\n    x2 = x[..., x.shape[-1] // 2 :]\n    return torch.cat((-x2, x1), dim=-1)\n\n\n@torch.compile(fullgraph=True)\ndef apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):\n    \"\"\"Applies Rotary Position Embedding to the query and key tensors.\n\n    Args:\n        q (`torch.Tensor`): The 
query tensor.\n        k (`torch.Tensor`): The key tensor.\n        cos (`torch.Tensor`): The cosine part of the rotary embedding.\n        sin (`torch.Tensor`): The sine part of the rotary embedding.\n        position_ids (`torch.Tensor`, *optional*):\n            Deprecated and unused.\n        unsqueeze_dim (`int`, *optional*, defaults to 1):\n            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and\n            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note\n            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and\n            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes\n            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have\n            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.\n    Returns:\n        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.\n    \"\"\"\n    cos = cos.unsqueeze(unsqueeze_dim)\n    sin = sin.unsqueeze(unsqueeze_dim)\n    q_embed = (q * cos) + (rotate_half(q) * sin)\n    k_embed = (k * cos) + (rotate_half(k) * sin)\n    return q_embed, k_embed\n\n\nclass PerFrameAttentionPooling(nn.Module):\n    \"\"\"\n    Per-frame multi-head attention pooling.\n\n    Given a flattened token sequence [B, L, D] and grid size (T, H, W), perform a\n    single-query attention pooling over the H*W tokens for each time frame, producing\n    [B, T, D].\n\n    Inspired by SigLIP's Multihead Attention Pooling head (without MLP/residual stack).\n    \"\"\"\n\n    def __init__(self, dim: int, num_heads: int, eps: float = 1e-6):\n        super().__init__()\n        assert dim % num_heads == 0, \"dim must be divisible by num_heads\"\n        self.dim = dim\n        self.num_heads = 
num_heads\n\n        self.probe = nn.Parameter(torch.randn(1, 1, dim))\n        nn.init.normal_(self.probe, std=0.02)\n\n        self.attention = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, batch_first=True)\n        self.layernorm = nn.LayerNorm(dim, eps=eps)\n\n    def forward(self, x: torch.Tensor, grid_size: Tuple[int, int, int]) -> torch.Tensor:\n        \"\"\"\n        Args:\n            x: [B, L, D], where L = T*H*W\n            grid_size: (T, H, W)\n        Returns:\n            pooled: [B, T, D]\n        \"\"\"\n        B, L, D = x.shape\n        T, H, W = grid_size\n        assert D == self.dim, f\"Channel dimension mismatch: D={D} vs dim={self.dim}\"\n        assert L == T * H * W, f\"Flattened length mismatch: L={L} vs T*H*W={T*H*W}\"\n\n        S = H * W\n        # Re-arrange tokens grouped by frame.\n        x_bt_s_d = x.view(B, T, S, D).contiguous().view(B * T, S, D)  # [B*T, S, D]\n\n        # A learnable probe as the query (one query per frame).\n        probe = self.probe.expand(B * T, -1, -1)  # [B*T, 1, D]\n\n        # Attention pooling: query=probe, key/value=H*W tokens within the frame.\n        pooled_bt_1_d = self.attention(probe, x_bt_s_d, x_bt_s_d, need_weights=False)[0]  # [B*T, 1, D]\n        pooled_bt_d = pooled_bt_1_d.squeeze(1)  # [B*T, D]\n\n        # Restore to [B, T, D].\n        pooled = pooled_bt_d.view(B, T, D)\n        pooled = self.layernorm(pooled)\n        return pooled\n\n\nclass CrossModalInteractionController:\n    \"\"\"\n    Strategy class that controls interactions between two towers.\n    Manages the interaction mapping between visual DiT (e.g. 30 layers) and audio DiT (e.g. 
30 layers).\n    \"\"\"\n\n    def __init__(self, visual_layers: int = 30, audio_layers: int = 30):\n        self.visual_layers = visual_layers\n        self.audio_layers = audio_layers\n        self.min_layers = min(visual_layers, audio_layers)\n\n    def get_interaction_layers(self, strategy: str = \"shallow_focus\") -> Dict[str, List[Tuple[int, int]]]:\n        \"\"\"\n        Get interaction layer mappings.\n\n        Args:\n            strategy: interaction strategy\n                - \"shallow_focus\": emphasize shallow layers to avoid deep-layer asymmetry\n                - \"distributed\": distributed interactions across the network\n                - \"progressive\": dense shallow interactions, sparse deeper interactions\n                - \"custom\": custom interaction layers\n\n        Returns:\n            A dict containing mappings for 'v2a' (visual -> audio) and 'a2v' (audio -> visual).\n        \"\"\"\n\n        if strategy == \"shallow_focus\":\n            # Emphasize the first ~1/3 layers to avoid deep-layer asymmetry.\n            num_interact = min(10, self.min_layers // 3)\n            interact_layers = list(range(0, num_interact))\n\n        elif strategy == \"distributed\":\n            # Distribute interactions across the network (every few layers).\n            step = 3\n            interact_layers = list(range(0, self.min_layers, step))\n\n        elif strategy == \"progressive\":\n            # Progressive: dense shallow interactions, sparse deeper interactions.\n            shallow = list(range(0, min(8, self.min_layers)))  # Dense for the first 8 layers.\n            if self.min_layers > 8:\n                deep = list(range(8, self.min_layers, 3))  # Every 3 layers afterwards.\n                interact_layers = shallow + deep\n            else:\n                interact_layers = shallow\n\n        elif strategy == \"custom\":\n            # Custom strategy: adjust as needed.\n            interact_layers = [0, 2, 4, 6, 8, 12, 16, 20]  # 
Explicit layer indices.\n            interact_layers = [i for i in interact_layers if i < self.min_layers]\n\n        elif strategy == \"full\":\n            interact_layers = list(range(0, self.min_layers))\n\n        else:\n            raise ValueError(f\"Unknown interaction strategy: {strategy}\")\n\n        # Build bidirectional mapping.\n        mapping = {\n            'v2a': [(i, i) for i in interact_layers],  # visual layer i -> audio layer i\n            'a2v': [(i, i) for i in interact_layers]   # audio layer i -> visual layer i\n        }\n\n        return mapping\n\n    def should_interact(self, layer_idx: int, direction: str, interaction_mapping: Dict) -> bool:\n        \"\"\"\n        Check whether a given layer should interact.\n\n        Args:\n            layer_idx: current layer index\n            direction: interaction direction ('v2a' or 'a2v')\n            interaction_mapping: interaction mapping table\n\n        Returns:\n            bool: whether to interact\n        \"\"\"\n        if direction not in interaction_mapping:\n            return False\n\n        return any(src == layer_idx for src, _ in interaction_mapping[direction])\n\n\nclass ConditionalCrossAttention(nn.Module):\n    def __init__(self, dim: int, kv_dim: int, num_heads: int, eps: float = 1e-6):\n        super().__init__()\n        self.q_dim = dim\n        self.kv_dim = kv_dim\n        self.num_heads = num_heads\n        self.head_dim = self.q_dim // num_heads\n\n        self.q = nn.Linear(dim, dim)\n        self.k = nn.Linear(kv_dim, dim)\n        self.v = nn.Linear(kv_dim, dim)\n        self.o = nn.Linear(dim, dim)\n        self.norm_q = RMSNorm(dim, eps=eps)\n        self.norm_k = RMSNorm(dim, eps=eps)\n\n        self.attn = AttentionModule(self.num_heads)\n\n    def forward(self, x: torch.Tensor, y: torch.Tensor, x_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, y_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None):\n        ctx = y\n        q = 
self.norm_q(self.q(x))\n        k = self.norm_k(self.k(ctx))\n        v = self.v(ctx)\n        if x_freqs is not None:\n            x_cos, x_sin = x_freqs\n            B, L, _ = q.shape\n            q_view = rearrange(q, 'b l (h d) -> b l h d', d=self.head_dim)\n            x_cos = x_cos.to(q_view.dtype).to(q_view.device)\n            x_sin = x_sin.to(q_view.dtype).to(q_view.device)\n            # Expect x_cos/x_sin shape: [B or 1, L, head_dim]\n            q_view, _ = apply_rotary_pos_emb(q_view, q_view, x_cos, x_sin, unsqueeze_dim=2)\n            q = rearrange(q_view, 'b l h d -> b l (h d)')\n        if y_freqs is not None:\n            y_cos, y_sin = y_freqs\n            Bc, Lc, _ = k.shape\n            k_view = rearrange(k, 'b l (h d) -> b l h d', d=self.head_dim)\n            y_cos = y_cos.to(k_view.dtype).to(k_view.device)\n            y_sin = y_sin.to(k_view.dtype).to(k_view.device)\n            # Expect y_cos/y_sin shape: [B or 1, L, head_dim]\n            _, k_view = apply_rotary_pos_emb(k_view, k_view, y_cos, y_sin, unsqueeze_dim=2)\n            k = rearrange(k_view, 'b l h d -> b l (h d)')\n        x = self.attn(q, k, v)\n        return self.o(x)\n\n\n# from diffusers.models.attention import AdaLayerNorm\nclass AdaLayerNorm(nn.Module):\n    r\"\"\"\n    Norm layer modified to incorporate timestep embeddings.\n\n    Parameters:\n        embedding_dim (`int`): The size of each embedding vector.\n        num_embeddings (`int`, *optional*): The size of the embeddings dictionary.\n        output_dim (`int`, *optional*):\n        norm_elementwise_affine (`bool`, defaults to `False):\n        norm_eps (`bool`, defaults to `False`):\n        chunk_dim (`int`, defaults to `0`):\n    \"\"\"\n\n    def __init__(\n        self,\n        embedding_dim: int,\n        num_embeddings: Optional[int] = None,\n        output_dim: Optional[int] = None,\n        norm_elementwise_affine: bool = False,\n        norm_eps: float = 1e-5,\n        chunk_dim: int = 0,\n    ):\n     
   super().__init__()\n\n        self.chunk_dim = chunk_dim\n        output_dim = output_dim or embedding_dim * 2\n\n        if num_embeddings is not None:\n            self.emb = nn.Embedding(num_embeddings, embedding_dim)\n        else:\n            self.emb = None\n\n        self.silu = nn.SiLU()\n        self.linear = nn.Linear(embedding_dim, output_dim)\n        self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)\n\n    def forward(\n        self, x: torch.Tensor, timestep: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None\n    ) -> torch.Tensor:\n        if self.emb is not None:\n            temb = self.emb(timestep)\n\n        temb = self.linear(self.silu(temb))\n\n        if self.chunk_dim == 2:\n            scale, shift = temb.chunk(2, dim=2)\n            # print(f\"{x.shape = }, {scale.shape = }, {shift.shape = }\")\n        elif self.chunk_dim == 1:\n            # This is a bit weird why we have the order of \"shift, scale\" here and \"scale, shift\" in the\n            # other if-branch. 
This branch is specific to CogVideoX and OmniGen for now.\n            shift, scale = temb.chunk(2, dim=1)\n            shift = shift[:, None, :]\n            scale = scale[:, None, :]\n        else:\n            scale, shift = temb.chunk(2, dim=0)\n\n        x = self.norm(x) * (1 + scale) + shift\n        return x\n\n\nclass ConditionalCrossAttentionBlock(nn.Module):\n    \"\"\"\n    A thin wrapper around ConditionalCrossAttention.\n    Applies LayerNorm to the conditioning input `y` before cross-attention.\n    \"\"\"\n    def __init__(self, dim: int, kv_dim: int, num_heads: int, eps: float = 1e-6, pooled_adaln: bool = False):\n        super().__init__()\n        self.y_norm = nn.LayerNorm(kv_dim, eps=eps)\n        self.inner = ConditionalCrossAttention(dim=dim, kv_dim=kv_dim, num_heads=num_heads, eps=eps)\n        self.pooled_adaln = pooled_adaln\n        if pooled_adaln:\n            self.per_frame_pooling = PerFrameAttentionPooling(kv_dim, num_heads=num_heads, eps=eps)\n            self.adaln = AdaLayerNorm(kv_dim, output_dim=dim*2, chunk_dim=2)\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        y: torch.Tensor,\n        x_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,\n        y_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,\n        video_grid_size: Optional[Tuple[int, int, int]] = None,\n    ) -> torch.Tensor:\n        if self.pooled_adaln:\n            assert video_grid_size is not None, \"video_grid_size must not be None\"\n            pooled_y = self.per_frame_pooling(y, video_grid_size)\n            # Interpolate pooled_y along its temporal dimension to match x's sequence length.\n            if pooled_y.shape[1] != x.shape[1]:\n                pooled_y = F.interpolate(\n                    pooled_y.permute(0, 2, 1),  # [B, C, T]\n                    size=x.shape[1],\n                    mode='linear',\n                    align_corners=False,\n                ).permute(0, 2, 1)  # [B, T, C]\n            x = 
self.adaln(x, temb=pooled_y)\n        y = self.y_norm(y)\n        return self.inner(x=x, y=y, x_freqs=x_freqs, y_freqs=y_freqs)\n\n\nclass DualTowerConditionalBridge(nn.Module):\n    \"\"\"\n    Dual-tower conditional bridge.\n    \"\"\"\n    def __init__(self,\n                 visual_layers: int = 40,\n                 audio_layers: int = 30,\n                 visual_hidden_dim: int = 5120,    # visual DiT hidden state dimension\n                 audio_hidden_dim: int = 1536,     # audio DiT hidden state dimension\n                 audio_fps: float = 50.0,\n                 head_dim: int = 128,              # attention head dimension\n                 interaction_strategy: str = \"full\",\n                 apply_cross_rope: bool = True,   # whether to apply RoPE in cross-attention\n                 apply_first_frame_bias_in_rope: bool = False,  # whether to account for 1/video_fps bias for the first frame in RoPE alignment\n                 trainable_condition_scale: bool = False,\n                 pooled_adaln: bool = False,\n                 ):\n        super().__init__()\n\n        self.visual_hidden_dim = visual_hidden_dim\n        self.audio_hidden_dim = audio_hidden_dim\n        self.audio_fps = audio_fps\n        self.head_dim = head_dim\n        self.apply_cross_rope = apply_cross_rope\n        self.apply_first_frame_bias_in_rope = apply_first_frame_bias_in_rope\n        self.trainable_condition_scale = trainable_condition_scale\n        self.pooled_adaln = pooled_adaln\n        if self.trainable_condition_scale:\n            self.condition_scale = nn.Parameter(torch.tensor([1.0], dtype=torch.float32))\n        else:\n            self.condition_scale = 1.0\n\n        self.controller = CrossModalInteractionController(visual_layers, audio_layers)\n        self.interaction_mapping = self.controller.get_interaction_layers(interaction_strategy)\n\n        # Conditional cross-attention modules operating at the DiT hidden-state level.\n        
self.audio_to_video_conditioners = nn.ModuleDict()  # audio hidden states -> visual DiT conditioning\n        self.video_to_audio_conditioners = nn.ModuleDict()  # visual hidden states -> audio DiT conditioning\n\n        # Build conditioners for layers that should interact.\n        # audio hidden states condition the visual DiT\n        self.rotary = RotaryEmbedding(base=10000.0, dim=head_dim)\n        for v_layer, _ in self.interaction_mapping['a2v']:\n            self.audio_to_video_conditioners[str(v_layer)] = ConditionalCrossAttentionBlock(\n                dim=visual_hidden_dim,     # 3072 (visual DiT hidden states)\n                kv_dim=audio_hidden_dim,    # 1536 (audio DiT hidden states)\n                num_heads=visual_hidden_dim // head_dim, # derive number of heads from hidden dim\n                pooled_adaln=False # a2v typically does not need pooled AdaLN\n            )\n\n        # visual hidden states condition the audio DiT\n        for a_layer, _ in self.interaction_mapping['v2a']:\n            self.video_to_audio_conditioners[str(a_layer)] = ConditionalCrossAttentionBlock(\n                dim=audio_hidden_dim,      # 1536 (audio DiT hidden states)\n                kv_dim=visual_hidden_dim,   # 3072 (visual DiT hidden states)\n                num_heads=audio_hidden_dim // head_dim, # safe head count derivation\n                pooled_adaln=self.pooled_adaln\n            )\n\n    @torch.no_grad()\n    def build_aligned_freqs(self,\n                            video_fps: float,\n                            grid_size: Tuple[int, int, int],\n                            audio_steps: int,\n                            device: Optional[torch.device] = None,\n                            dtype: Optional[torch.dtype] = None) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:\n        \"\"\"\n        Build aligned RoPE (cos, sin) pairs based on video fps, video grid size (f_v, h, w),\n        and audio sequence length 
`audio_steps` (with fixed audio fps = 44100/2048).\n\n        Returns:\n            visual_freqs: (cos_v, sin_v), shape [1, f_v*h*w, head_dim]\n            audio_freqs:  (cos_a, sin_a), shape [1, audio_steps, head_dim]\n        \"\"\"\n        f_v, h, w = grid_size\n        L_v = f_v * h * w\n        L_a = int(audio_steps)\n\n        device = device or next(self.parameters()).device\n        dtype = dtype or torch.float32\n\n        # Audio positions: 0,1,2,...,L_a-1 (audio as reference).\n        audio_pos = torch.arange(L_a, device=device, dtype=torch.float32).unsqueeze(0)\n\n        # Video positions: align video frames to audio-step units.\n        # FIXME(dhyu): hard-coded VAE temporal stride = 4\n        if self.apply_first_frame_bias_in_rope:\n            # Account for the \"first frame lasts 1/video_fps\" bias.\n            video_effective_fps = float(video_fps) / 4.0\n            if f_v > 0:\n                t_starts = torch.zeros((f_v,), device=device, dtype=torch.float32)\n                if f_v > 1:\n                    t_starts[1:] = (1.0 / float(video_fps)) + torch.arange(f_v - 1, device=device, dtype=torch.float32) * (1.0 / video_effective_fps)\n            else:\n                t_starts = torch.zeros((0,), device=device, dtype=torch.float32)\n            # Convert to audio-step units.\n            video_pos_per_frame = t_starts * float(self.audio_fps)\n        else:\n            # No first-frame bias: uniform alignment.\n            scale = float(self.audio_fps) / float(video_fps / 4.0)\n            video_pos_per_frame = torch.arange(f_v, device=device, dtype=torch.float32) * scale\n        # Flatten to f*h*w; tokens within the same frame share the same time position.\n        video_pos = video_pos_per_frame.repeat_interleave(h * w).unsqueeze(0)\n\n        # print(f\"video fps: {video_fps}, audio fps: {self.audio_fps}, scale: {scale}\")\n        # print(f\"video pos: {video_pos.shape}, audio pos: {audio_pos.shape}\")\n\n        # Build dummy x to 
produce cos/sin, dim=head_dim.\n        dummy_v = torch.zeros((1, L_v, self.head_dim), device=device, dtype=dtype)\n        dummy_a = torch.zeros((1, L_a, self.head_dim), device=device, dtype=dtype)\n\n        cos_v, sin_v = self.rotary(dummy_v, position_ids=video_pos)\n        cos_a, sin_a = self.rotary(dummy_a, position_ids=audio_pos)\n\n        return (cos_v, sin_v), (cos_a, sin_a)\n\n    def should_interact(self, layer_idx: int, direction: str) -> bool:\n        return self.controller.should_interact(layer_idx, direction, self.interaction_mapping)\n\n    def apply_conditional_control(\n        self,\n        layer_idx: int,\n        direction: str,\n        primary_hidden_states: torch.Tensor,\n        condition_hidden_states: torch.Tensor,\n        x_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,\n        y_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,\n        condition_scale: Optional[float] = None,\n        video_grid_size: Optional[Tuple[int, int, int]] = None,\n        use_gradient_checkpointing: Optional[bool] = False,\n        use_gradient_checkpointing_offload: Optional[bool] = False,\n    ) -> torch.Tensor:\n        \"\"\"\n        Apply conditional control (at the DiT hidden-state level).\n\n        Args:\n            layer_idx: current layer index\n            direction: conditioning direction\n                - 'a2v': audio hidden states -> visual DiT\n                - 'v2a': visual hidden states -> audio DiT\n            primary_hidden_states: primary DiT hidden states [B, L, hidden_dim]\n            condition_hidden_states: condition DiT hidden states [B, L, hidden_dim]\n            condition_scale: conditioning strength (similar to CFG scale)\n\n        Returns:\n            Conditioned primary DiT hidden states [B, L, hidden_dim]\n        \"\"\"\n\n        if not self.controller.should_interact(layer_idx, direction, self.interaction_mapping):\n            return primary_hidden_states\n\n        if direction == 'a2v':\n 
           # audio hidden states condition the visual DiT\n            conditioner = self.audio_to_video_conditioners[str(layer_idx)]\n\n        elif direction == 'v2a':\n            # visual hidden states condition the audio DiT\n            conditioner = self.video_to_audio_conditioners[str(layer_idx)]\n        else:\n            raise ValueError(f\"Invalid direction: {direction}\")\n\n        conditioned_features = gradient_checkpoint_forward(\n            conditioner,\n            use_gradient_checkpointing,\n            use_gradient_checkpointing_offload,\n            x=primary_hidden_states,\n            y=condition_hidden_states,\n            x_freqs=x_freqs,\n            y_freqs=y_freqs,\n            video_grid_size=video_grid_size,\n        )\n\n        if self.trainable_condition_scale and condition_scale is not None:\n            print(\n                \"[WARN] This model has a trainable condition_scale, but an external \"\n                f\"condition_scale={condition_scale} was provided. 
The trainable condition_scale \"\n                \"will be ignored in favor of the external value.\"\n            )\n\n        scale = condition_scale if condition_scale is not None else self.condition_scale\n\n        primary_hidden_states = primary_hidden_states + conditioned_features * scale\n\n        return primary_hidden_states\n\n    def forward(\n        self,\n        layer_idx: int,\n        visual_hidden_states: torch.Tensor,\n        audio_hidden_states: torch.Tensor,\n        *,\n        x_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,\n        y_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,\n        a2v_condition_scale: Optional[float] = None,\n        v2a_condition_scale: Optional[float] = None,\n        condition_scale: Optional[float] = None,\n        video_grid_size: Optional[Tuple[int, int, int]] = None,\n        use_gradient_checkpointing: Optional[bool] = False,\n        use_gradient_checkpointing_offload: Optional[bool] = False,\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"\n        Apply bidirectional conditional control to both visual/audio towers.\n\n        Args:\n            layer_idx: current layer index\n            visual_hidden_states: visual DiT hidden states\n            audio_hidden_states: audio DiT hidden states\n            x_freqs / y_freqs: cross-modal RoPE (cos, sin) pairs.\n                If provided, x_freqs is assumed to correspond to the primary tower and y_freqs\n                to the conditioning tower.\n            a2v_condition_scale: audio->visual conditioning strength (overrides global condition_scale)\n            v2a_condition_scale: visual->audio conditioning strength (overrides global condition_scale)\n            condition_scale: fallback conditioning strength when per-direction scale is None\n            video_grid_size: (F, H, W), used on the audio side when pooled_adaln is enabled\n\n        Returns:\n            (visual_hidden_states, audio_hidden_states), both 
conditioned in their respective directions.\n        \"\"\"\n\n        visual_conditioned = self.apply_conditional_control(\n            layer_idx=layer_idx,\n            direction=\"a2v\",\n            primary_hidden_states=visual_hidden_states,\n            condition_hidden_states=audio_hidden_states,\n            x_freqs=x_freqs,\n            y_freqs=y_freqs,\n            condition_scale=a2v_condition_scale if a2v_condition_scale is not None else condition_scale,\n            video_grid_size=video_grid_size,\n            use_gradient_checkpointing=use_gradient_checkpointing,\n            use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,\n        )\n\n        audio_conditioned = self.apply_conditional_control(\n            layer_idx=layer_idx,\n            direction=\"v2a\",\n            primary_hidden_states=audio_hidden_states,\n            condition_hidden_states=visual_hidden_states,\n            x_freqs=y_freqs,\n            y_freqs=x_freqs,\n            condition_scale=v2a_condition_scale if v2a_condition_scale is not None else condition_scale,\n            video_grid_size=video_grid_size,\n            use_gradient_checkpointing=use_gradient_checkpointing,\n            use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,\n        )\n\n        return visual_conditioned, audio_conditioned\n"
  },
  {
    "path": "diffsynth/models/nexus_gen.py",
    "content": "import torch\nfrom PIL import Image\n\n\nclass NexusGenAutoregressiveModel(torch.nn.Module):\n    def __init__(self, max_length=1024, max_pixels=262640):\n        super(NexusGenAutoregressiveModel, self).__init__()\n        from .nexus_gen_ar_model import Qwen2_5_VLForConditionalGeneration\n        from transformers import Qwen2_5_VLConfig\n        self.max_length = max_length\n        self.max_pixels = max_pixels\n        model_config = Qwen2_5_VLConfig(**{\n            \"_name_or_path\": \"DiffSynth-Studio/Nexus-GenV2\",\n            \"architectures\": [\n                \"Qwen2_5_VLForConditionalGeneration\"\n            ],\n            \"attention_dropout\": 0.0,\n            \"auto_map\": {\n                \"AutoConfig\": \"configuration_qwen2_5_vl.Qwen2_5_VLConfig\",\n                \"AutoModel\": \"modeling_qwen2_5_vl.Qwen2_5_VLModel\",\n                \"AutoModelForCausalLM\": \"modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration\"\n            },\n            \"bos_token_id\": 151643,\n            \"eos_token_id\": 151645,\n            \"hidden_act\": \"silu\",\n            \"hidden_size\": 3584,\n            \"image_token_id\": 151655,\n            \"initializer_range\": 0.02,\n            \"intermediate_size\": 18944,\n            \"max_position_embeddings\": 128000,\n            \"max_window_layers\": 28,\n            \"model_type\": \"qwen2_5_vl\",\n            \"num_attention_heads\": 28,\n            \"num_hidden_layers\": 28,\n            \"num_key_value_heads\": 4,\n            \"pad_token_id\": 151643,\n            \"rms_norm_eps\": 1e-06,\n            \"rope_scaling\": {\n                \"mrope_section\": [\n                16,\n                24,\n                24\n                ],\n                \"rope_type\": \"default\",\n                \"type\": \"default\"\n            },\n            \"rope_theta\": 1000000.0,\n            \"sliding_window\": 32768,\n            \"tie_word_embeddings\": False,\n            
\"torch_dtype\": \"bfloat16\",\n            \"transformers_version\": \"4.49.0\",\n            \"use_cache\": False,\n            \"use_sliding_window\": False,\n            \"video_token_id\": 151656,\n            \"vision_config\": {\n                \"hidden_size\": 1280,\n                \"in_chans\": 3,\n                \"model_type\": \"qwen2_5_vl\",\n                \"spatial_patch_size\": 14,\n                \"tokens_per_second\": 2,\n                \"torch_dtype\": \"bfloat16\"\n            },\n            \"vision_end_token_id\": 151653,\n            \"vision_start_token_id\": 151652,\n            \"vision_token_id\": 151654,\n            \"vocab_size\": 152064\n        })\n        self.model = Qwen2_5_VLForConditionalGeneration(model_config)\n        self.processor = None\n        \n        \n    def load_processor(self, path):\n        from .nexus_gen_ar_model import Qwen2_5_VLProcessor\n        self.processor = Qwen2_5_VLProcessor.from_pretrained(path)\n\n\n    @staticmethod\n    def state_dict_converter():\n        return NexusGenAutoregressiveModelStateDictConverter()\n\n    def bound_image(self, image, max_pixels=262640):\n        from qwen_vl_utils import smart_resize\n        resized_height, resized_width = smart_resize(\n            image.height,\n            image.width,\n            max_pixels=max_pixels,\n        )\n        return image.resize((resized_width, resized_height))\n\n    def get_editing_msg(self, instruction):\n        if '<image>' not in instruction:\n            instruction = '<image> ' + instruction\n        messages = [{\"role\":\"user\", \"content\":instruction}, {\"role\":\"assistant\", \"content\":\"Here is the image: <image>\"}]\n        return messages\n\n    def get_generation_msg(self, instruction):\n        instruction = \"Generate an image according to the following description: {}\".format(instruction)\n        messages = [{\"role\":\"user\", \"content\":instruction}, {\"role\":\"assistant\", \"content\":\"Here is 
an image based on the description: <image>\"}]\n        return messages\n\n    def forward(self, instruction, ref_image=None, num_img_tokens=81):\n        \"\"\"\n        Generate target embeddings for the given instruction and reference image.\n        \"\"\"\n        if ref_image is not None:\n            messages = self.get_editing_msg(instruction)\n            images = [self.bound_image(ref_image)] + [Image.new(mode='RGB', size=(252, 252), color=(255, 255, 255))]\n            output_image_embeddings = self.get_target_embeddings(images, messages, self.processor, self.model, num_img_tokens)\n        else:\n            messages = self.get_generation_msg(instruction)\n            images = [Image.new(mode='RGB', size=(252, 252), color=(255, 255, 255))]\n            output_image_embeddings = self.get_target_embeddings(images, messages, self.processor, self.model, num_img_tokens)\n\n        return output_image_embeddings\n\n    def get_target_embeddings(self, images, messages, processor, model, num_img_tokens=81):\n        text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)\n        text = text.replace('<image>', '<|vision_start|><|image_pad|><|vision_end|>')\n        inputs = processor(\n            text=[text],\n            images=images,\n            padding=True,\n            return_tensors=\"pt\",\n        )\n        inputs = inputs.to(model.device)\n\n        input_embeds = model.model.embed_tokens(inputs['input_ids'])\n        image_embeds = model.visual(inputs['pixel_values'], grid_thw=inputs['image_grid_thw'])\n        ground_truth_image_embeds = image_embeds[-num_img_tokens:]\n        input_image_embeds = image_embeds[:-num_img_tokens]\n\n        image_mask = inputs['input_ids'] == model.config.image_token_id\n        indices = image_mask.cumsum(dim=1)\n        input_image_mask = torch.logical_and(indices <= (image_embeds.shape[0] - ground_truth_image_embeds.shape[0]), image_mask)\n        gt_image_mask = 
torch.logical_and(image_mask, ~input_image_mask)\n        input_image_mask = input_image_mask.unsqueeze(-1).expand_as(input_embeds)\n        input_embeds = input_embeds.masked_scatter(input_image_mask, input_image_embeds)\n\n        image_prefill_embeds = model.image_prefill_embeds(\n            torch.arange(81, device=model.device).long()\n        )\n        input_embeds = input_embeds.masked_scatter(gt_image_mask.unsqueeze(-1).expand_as(input_embeds), image_prefill_embeds)\n\n        position_ids, _ = model.get_rope_index(\n            inputs['input_ids'],\n            inputs['image_grid_thw'],\n            attention_mask=inputs['attention_mask'])\n        position_ids = position_ids.contiguous()\n        outputs = model(inputs_embeds=input_embeds, position_ids=position_ids, attention_mask=inputs['attention_mask'], return_dict=True)\n        output_image_embeddings = outputs.image_embeddings[:, :-1, :]\n        output_image_embeddings = output_image_embeddings[gt_image_mask[:, 1:]]\n        return output_image_embeddings, input_image_embeds, inputs['image_grid_thw']\n\n\nclass NexusGenAutoregressiveModelStateDictConverter:\n    def __init__(self):\n        pass\n\n    def from_civitai(self, state_dict):\n        state_dict = {\"model.\" + key: value for key, value in state_dict.items()}\n        return state_dict\n"
  },
  {
    "path": "diffsynth/models/nexus_gen_ar_model.py",
    "content": "import os\nimport re\nfrom dataclasses import dataclass\nfrom typing import Any, Dict, List, Optional, Tuple, Union\n\nimport torch\nimport torch.nn as nn\nfrom torch.nn import CrossEntropyLoss\n\nfrom transformers.cache_utils import Cache\nfrom transformers.generation import GenerationMixin, LogitsProcessorList, StoppingCriteriaList, GenerationConfig, GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput\nfrom transformers.utils import add_start_docstrings_to_model_forward, logging, replace_return_docstrings\nfrom transformers.modeling_outputs import ModelOutput\nfrom transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import Qwen2_5_VLConfig\nfrom transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (\n    Qwen2_5_VisionTransformerPretrainedModel,\n    Qwen2_5_VLModel,\n    Qwen2_5_VLPreTrainedModel,\n    QWEN2_5_VL_INPUTS_DOCSTRING,\n    )\n\nfrom transformers.feature_extraction_utils import BatchFeature\nfrom transformers.image_utils import ImageInput, VideoInput\nfrom transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, VideosKwargs\nfrom transformers.tokenization_utils_base import PreTokenizedInput, TextInput\n\nGenerateNonBeamOutput = Union[GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput]\n\nlogger = logging.get_logger(__name__)\n\n_CONFIG_FOR_DOC = \"Qwen2_5_VLConfig\"\n\n\n@dataclass\nclass Qwen2_5_VLCausalLMOutputWithPast(ModelOutput):\n    \"\"\"\n    Base class for Qwen2_5_VL causal language model (or autoregressive) outputs.\n\n    Args:\n        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):\n            Language modeling loss (for next-token prediction).\n        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):\n            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).\n        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when 
`use_cache=True` is passed or when `config.use_cache=True`):\n            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape\n            `(batch_size, num_heads, sequence_length, embed_size_per_head)`)\n\n            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see\n            `past_key_values` input) to speed up sequential decoding.\n        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):\n            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +\n            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.\n\n            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.\n        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):\n            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,\n            sequence_length)`.\n\n            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention\n            heads.\n        rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):\n            The rope index difference between sequence length and multimodal rope.\n    \"\"\"\n\n    loss: Optional[torch.FloatTensor] = None\n    logits: torch.FloatTensor = None\n    image_embeddings: torch.FloatTensor = None\n    past_key_values: Optional[List[torch.FloatTensor]] = None\n    hidden_states: Optional[Tuple[torch.FloatTensor]] = None\n    attentions: Optional[Tuple[torch.FloatTensor]] = None\n    rope_deltas: Optional[torch.LongTensor] = None\n\n\nclass 
Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMixin):\n    _tied_weights_keys = [\"lm_head.weight\"]\n    config_class = Qwen2_5_VLConfig\n    _no_split_modules = [\"Qwen2_5_VLDecoderLayer\", \"Qwen2_5_VLVisionBlock\"]\n\n    def __init__(self, config):\n        super().__init__(config)\n        self.visual = Qwen2_5_VisionTransformerPretrainedModel._from_config(config.vision_config)\n        self.model = Qwen2_5_VLModel(config)\n        self.vocab_size = config.vocab_size\n        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n        self.vision_head = nn.Linear(config.hidden_size, config.hidden_size, bias=False)\n        self.rope_deltas = None  # cache rope_deltas here\n        self.image_prefill_embeds = nn.Embedding(81, config.hidden_size)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.model.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.model.embed_tokens = value\n\n    def get_output_embeddings(self):\n        return self.lm_head\n\n    def set_output_embeddings(self, new_embeddings):\n        self.lm_head = new_embeddings\n\n    def set_decoder(self, decoder):\n        self.model = decoder\n\n    def get_decoder(self):\n        return self.model\n\n    def get_rope_index(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        image_grid_thw: Optional[torch.LongTensor] = None,\n        video_grid_thw: Optional[torch.LongTensor] = None,\n        second_per_grid_ts: Optional[torch.Tensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"\n        Calculate the 3D rope index based on image and video's temporal, height and width in LLM.\n\n        Explanation:\n            Each embedding sequence contains vision embedding and text embedding or just contains text embedding.\n\n            
For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs.\n            Examples:\n                input_ids: [T T T T T], here T is for text.\n                temporal position_ids: [0, 1, 2, 3, 4]\n                height position_ids: [0, 1, 2, 3, 4]\n                width position_ids: [0, 1, 2, 3, 4]\n\n            For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part\n            and 1D rotary position embedding for text part.\n            Examples:\n                Temporal (Time): 3 patches, representing different segments of the video in time.\n                Height: 2 patches, dividing each frame vertically.\n                Width: 2 patches, dividing each frame horizontally.\n                We also have some important parameters:\n                fps (Frames Per Second): The video's frame rate, set to 1. This means one frame is processed each second.\n                tokens_per_second: This is a crucial parameter. It dictates how many \"time-steps\" or \"temporal tokens\" are conceptually packed into a one-second interval of the video. In this case, we have 25 tokens per second. So each second of the video will be represented with 25 separate time points. It essentially defines the temporal granularity.\n                temporal_patch_size: The number of frames that compose one temporal patch. Here, it's 2 frames.\n                interval: The step size for the temporal position IDs, calculated as tokens_per_second * temporal_patch_size / fps. In this case, 25 * 2 / 1 = 50. 
This means that each temporal patch will be have a difference of 50 in the temporal position IDs.\n                input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision.\n                vision temporal position_ids: [0, 0, 0, 0, 50, 50, 50, 50, 100, 100, 100, 100]\n                vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]\n                vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]\n                text temporal position_ids: [101, 102, 103, 104, 105]\n                text height position_ids: [101, 102, 103, 104, 105]\n                text width position_ids: [101, 102, 103, 104, 105]\n                Here we calculate the text start position_ids as the max vision position_ids plus 1.\n\n        Args:\n            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):\n                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide\n                it.\n            image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):\n                The temporal, height and width of feature shape of each image in LLM.\n            video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):\n                The temporal, height and width of feature shape of each video in LLM.\n            second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*):\n                The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs.\n            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):\n                Mask to avoid performing attention on padding token indices. 
Mask values selected in `[0, 1]`:\n\n                - 1 for tokens that are **not masked**,\n                - 0 for tokens that are **masked**.\n\n        Returns:\n            position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`)\n            mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`)\n        \"\"\"\n        spatial_merge_size = self.config.vision_config.spatial_merge_size\n        image_token_id = self.config.image_token_id\n        video_token_id = self.config.video_token_id\n        vision_start_token_id = self.config.vision_start_token_id\n        mrope_position_deltas = []\n        if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):\n            total_input_ids = input_ids\n            if attention_mask is None:\n                attention_mask = torch.ones_like(total_input_ids)\n            position_ids = torch.ones(\n                3,\n                input_ids.shape[0],\n                input_ids.shape[1],\n                dtype=input_ids.dtype,\n                device=input_ids.device,\n            )\n            image_index, video_index = 0, 0\n            attention_mask = attention_mask.to(total_input_ids.device)\n            for i, input_ids in enumerate(total_input_ids):\n                input_ids = input_ids[attention_mask[i] == 1]\n                image_nums, video_nums = 0, 0\n                vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)\n                vision_tokens = input_ids[vision_start_indices + 1]\n                image_nums = (vision_tokens == image_token_id).sum()\n                video_nums = (vision_tokens == video_token_id).sum()\n                input_tokens = input_ids.tolist()\n                llm_pos_ids_list: list = []\n                st = 0\n                remain_images, remain_videos = image_nums, video_nums\n                for _ in range(image_nums + video_nums):\n                    if image_token_id in 
input_tokens and remain_images > 0:\n                        ed_image = input_tokens.index(image_token_id, st)\n                    else:\n                        ed_image = len(input_tokens) + 1\n                    if video_token_id in input_tokens and remain_videos > 0:\n                        ed_video = input_tokens.index(video_token_id, st)\n                    else:\n                        ed_video = len(input_tokens) + 1\n                    if ed_image < ed_video:\n                        t, h, w = (\n                            image_grid_thw[image_index][0],\n                            image_grid_thw[image_index][1],\n                            image_grid_thw[image_index][2],\n                        )\n                        second_per_grid_t = 0\n                        image_index += 1\n                        remain_images -= 1\n                        ed = ed_image\n\n                    else:\n                        t, h, w = (\n                            video_grid_thw[video_index][0],\n                            video_grid_thw[video_index][1],\n                            video_grid_thw[video_index][2],\n                        )\n                        if second_per_grid_ts is not None:\n                            second_per_grid_t = second_per_grid_ts[video_index]\n                        else:\n                            second_per_grid_t = 1.0\n                        video_index += 1\n                        remain_videos -= 1\n                        ed = ed_video\n                    llm_grid_t, llm_grid_h, llm_grid_w = (\n                        t.item(),\n                        h.item() // spatial_merge_size,\n                        w.item() // spatial_merge_size,\n                    )\n                    text_len = ed - st\n\n                    st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0\n                    llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + 
st_idx)\n\n                    range_tensor = torch.arange(llm_grid_t).view(-1, 1)\n                    expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w)\n\n                    time_tensor = expanded_range * second_per_grid_t * self.config.vision_config.tokens_per_second\n\n                    time_tensor_long = time_tensor.long()\n                    t_index = time_tensor_long.flatten()\n\n                    h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()\n                    w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()\n                    llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)\n                    st = ed + llm_grid_t * llm_grid_h * llm_grid_w\n\n                if st < len(input_tokens):\n                    st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0\n                    text_len = len(input_tokens) - st\n                    llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)\n\n                llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)\n                position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)\n                mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))\n            mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)\n            return position_ids, mrope_position_deltas\n        else:\n            if attention_mask is not None:\n                position_ids = attention_mask.long().cumsum(-1) - 1\n                position_ids.masked_fill_(attention_mask == 0, 1)\n                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)\n                max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]\n                
mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]\n            else:\n                position_ids = (\n                    torch.arange(input_ids.shape[1], device=input_ids.device)\n                    .view(1, 1, -1)\n                    .expand(3, input_ids.shape[0], -1)\n                )\n                mrope_position_deltas = torch.zeros(\n                    [input_ids.shape[0], 1],\n                    device=input_ids.device,\n                    dtype=input_ids.dtype,\n                )\n\n            return position_ids, mrope_position_deltas\n\n    @add_start_docstrings_to_model_forward(QWEN2_5_VL_INPUTS_DOCSTRING)\n    @replace_return_docstrings(output_type=Qwen2_5_VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        pixel_values: Optional[torch.Tensor] = None,\n        pixel_values_videos: Optional[torch.FloatTensor] = None,\n        image_grid_thw: Optional[torch.LongTensor] = None,\n        video_grid_thw: Optional[torch.LongTensor] = None,\n        rope_deltas: Optional[torch.LongTensor] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n        second_per_grid_ts: Optional[torch.Tensor] = None,\n        image_embeddings: Optional[torch.Tensor] = None,\n        token_loss_weight: Optional[float] = 0.1,\n        img_loss_weight: Optional[float] = 1.0,\n    ) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]:\n        r\"\"\"\n            
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,\n                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored\n                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.\n\n        Returns:\n\n        Example:\n\n        ```python\n        >>> from PIL import Image\n        >>> import requests\n        >>> from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration\n\n        >>> model = Qwen2_5_VLForConditionalGeneration.from_pretrained(\"Qwen/Qwen2.5-VL-7B-Instruct\")\n        >>> processor = AutoProcessor.from_pretrained(\"Qwen/Qwen2.5-VL-7B-Instruct\")\n\n        >>> messages = [\n            {\n                \"role\": \"user\",\n                \"content\": [\n                    {\"type\": \"image\"},\n                    {\"type\": \"text\", \"text\": \"What is shown in this image?\"},\n                ],\n            },\n        ]\n        >>> url = \"https://www.ilankelman.org/stopsigns/australia.jpg\"\n        >>> image = Image.open(requests.get(url, stream=True).raw)\n\n        >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n        >>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos])\n\n        >>> # Generate\n        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)\n        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]\n        \"The image shows a street scene with a red stop sign in the foreground. 
In the background, there is a large red gate with Chinese characters ...\"\n        ```\"\"\"\n\n        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        )\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        if inputs_embeds is None:\n            # test feature\n            inputs_embeds = self.model.embed_tokens(input_ids)\n            # for image encoding and training\n            if pixel_values is not None:\n                pixel_values = pixel_values.type(self.visual.dtype)\n                image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)\n                n_image_tokens = (input_ids == self.config.image_token_id).sum().item()\n                n_image_features = image_embeds.shape[0]\n                if n_image_tokens != n_image_features:\n                    raise ValueError(\n                        f\"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}\"\n                    )\n\n                mask = input_ids == self.config.image_token_id\n                mask_unsqueezed = mask.unsqueeze(-1)\n                mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)\n                image_mask = mask_expanded.to(inputs_embeds.device)\n\n                image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)\n                inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)\n\n            if pixel_values_videos is not None:\n                pixel_values_videos = pixel_values_videos.type(self.visual.dtype)\n                video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)\n                n_video_tokens = (input_ids == self.config.video_token_id).sum().item()\n                
n_video_features = video_embeds.shape[0]\n                if n_video_tokens != n_video_features:\n                    raise ValueError(\n                        f\"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}\"\n                    )\n\n                mask = input_ids == self.config.video_token_id\n                mask_unsqueezed = mask.unsqueeze(-1)\n                mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)\n                video_mask = mask_expanded.to(inputs_embeds.device)\n\n                video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)\n                inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)\n\n            if attention_mask is not None:\n                attention_mask = attention_mask.to(inputs_embeds.device)\n\n        # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme\n        if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):\n            # calculate RoPE index once per generation in the pre-fill stage only\n            if (\n                (cache_position is not None and cache_position[0] == 0)\n                or self.rope_deltas is None\n                or (past_key_values is None or past_key_values.get_seq_length() == 0)\n            ):\n                position_ids, rope_deltas = self.get_rope_index(\n                    input_ids,\n                    image_grid_thw,\n                    video_grid_thw,\n                    second_per_grid_ts,\n                    attention_mask,\n                )\n                self.rope_deltas = rope_deltas\n            # then use the prev pre-calculated rope-deltas to get the correct position ids\n            else:\n                batch_size, seq_length, _ = inputs_embeds.shape\n                delta = (\n                    (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)\n                    if 
cache_position is not None\n                    else 0\n                )\n                position_ids = torch.arange(seq_length, device=inputs_embeds.device)\n                position_ids = position_ids.view(1, -1).expand(batch_size, -1)\n                if cache_position is not None:  # otherwise `deltas` is an int `0`\n                    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)\n                position_ids = position_ids.add(delta)\n                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)\n        # position_ids [3, B, L]\n\n        outputs = self.model(\n            input_ids=None,\n            position_ids=position_ids,\n            attention_mask=attention_mask,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n            cache_position=cache_position,\n        )\n\n        hidden_states = outputs[0]\n        logits = self.lm_head(hidden_states)\n        image_embeds = self.vision_head(hidden_states)\n\n        loss = None\n        if labels is not None:\n            # Upcast to float if we need to compute the loss to avoid potential precision issues\n            # prepare labels for logits\n            logits_labels = labels.clone().detach()\n            image_tokens = (labels == self.config.image_token_id)\n            logits_labels[image_tokens] = -100\n\n            logits = logits.float()\n            # Shift so that tokens < n predict n\n            shift_logits = logits[..., :-1, :].contiguous()\n            shift_labels = logits_labels[..., 1:].contiguous()\n            # Flatten the tokens\n            loss_fct = CrossEntropyLoss()\n            shift_logits = shift_logits.view(-1, self.config.vocab_size)\n            shift_labels = shift_labels.view(-1)\n            # Enable model 
parallelism\n            shift_labels = shift_labels.to(shift_logits.device)\n            loss = loss_fct(shift_logits, shift_labels) * token_loss_weight\n\n            shift_image_tokens_2d = (labels[..., 1:].contiguous() == self.config.image_token_id) # (B, L-1)\n            shifted_image_embeds = image_embeds[:, :-1, :].contiguous()  # (B, L-1, D)\n            masked_image_embeds = shifted_image_embeds[shift_image_tokens_2d]  # (num_image_tokens, D)\n\n            mse_loss_fct = nn.MSELoss()\n            mse_loss_fct = mse_loss_fct.to(shift_logits.device)\n            if image_embeddings is None:\n                image_embeddings = torch.zeros_like(masked_image_embeds)\n            img_loss = mse_loss_fct(masked_image_embeds, image_embeddings)\n\n            cos_sim = torch.cosine_similarity(\n                masked_image_embeds,\n                image_embeddings,\n                dim=-1\n            )\n            cos_loss = (1 - cos_sim).mean()\n            img_loss = 0.5 * img_loss + 0.5 * cos_loss\n            # fix nan for empty image tokens\n            if image_embeddings.size(0) == 0:\n                img_loss = img_loss.nan_to_num(0.0)\n            # combine the loss\n            loss = loss + img_loss_weight * img_loss\n\n        if not return_dict:\n            output = (logits,) + outputs[1:]\n            return (loss,) + output if loss is not None else output\n\n        return Qwen2_5_VLCausalLMOutputWithPast(\n            loss=loss,\n            logits=logits,\n            image_embeddings=image_embeds,\n            past_key_values=outputs.past_key_values,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n            rope_deltas=self.rope_deltas,\n        )\n\n\n\n    def _sample(\n        self,\n        input_ids: torch.LongTensor,\n        logits_processor: LogitsProcessorList,\n        stopping_criteria: StoppingCriteriaList,\n        generation_config: GenerationConfig,\n        synced_gpus: bool,\n  
      streamer: Optional[\"BaseStreamer\"],\n        **model_kwargs,\n    ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:\n        r\"\"\"\n        Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and\n        can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.\n\n        Parameters:\n            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):\n                The sequence used as a prompt for the generation.\n            logits_processor (`LogitsProcessorList`):\n                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]\n                used to modify the prediction scores of the language modeling head applied at each generation step.\n            stopping_criteria (`StoppingCriteriaList`):\n                An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]\n                used to tell if the generation loop should stop.\n            generation_config ([`~generation.GenerationConfig`]):\n                The generation configuration to be used as parametrization of the decoding method.\n            synced_gpus (`bool`):\n                Whether to continue running the while loop until max_length (needed to avoid deadlocking with\n                `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).\n            streamer (`BaseStreamer`, *optional*):\n                Streamer object that will be used to stream the generated sequences. Generated tokens are passed\n                through `streamer.put(token_ids)` and the streamer is responsible for any further processing.\n            model_kwargs:\n                Additional model specific kwargs will be forwarded to the `forward` function of the model. 
If model is\n                an encoder-decoder model the kwargs should include `encoder_outputs`.\n\n        Return:\n            [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or `torch.LongTensor`:\n            A `torch.LongTensor` containing the generated tokens (default behaviour) or a\n            [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and\n            `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if\n            `model.config.is_encoder_decoder=True`.\n        \"\"\"\n        # init values\n        pad_token_id = generation_config._pad_token_tensor\n        output_attentions = generation_config.output_attentions\n        output_hidden_states = generation_config.output_hidden_states\n        output_scores = generation_config.output_scores\n        output_logits = generation_config.output_logits\n        return_dict_in_generate = generation_config.return_dict_in_generate\n        max_length = generation_config.max_length\n        has_eos_stopping_criteria = any(hasattr(criteria, \"eos_token_id\") for criteria in stopping_criteria)\n        do_sample = generation_config.do_sample\n\n        # init attention / hidden states / scores tuples\n        scores = () if (return_dict_in_generate and output_scores) else None\n        raw_logits = () if (return_dict_in_generate and output_logits) else None\n        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None\n        cross_attentions = () if (return_dict_in_generate and output_attentions) else None\n        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None\n\n        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states\n        if return_dict_in_generate and self.config.is_encoder_decoder:\n            encoder_attentions = model_kwargs[\"encoder_outputs\"].get(\"attentions\") if 
output_attentions else None\n            encoder_hidden_states = (\n                model_kwargs[\"encoder_outputs\"].get(\"hidden_states\") if output_hidden_states else None\n            )\n\n        # keep track of which sequences are already finished\n        batch_size, cur_len = input_ids.shape\n        this_peer_finished = False\n        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)\n        model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)\n\n        model_forward = self.__call__\n        if isinstance(model_kwargs.get(\"past_key_values\"), Cache):\n            is_compileable = model_kwargs[\"past_key_values\"].is_compileable and self._supports_static_cache\n            is_compileable = is_compileable and not self.generation_config.disable_compile\n            if is_compileable and (\n                self.device.type in [\"cuda\", \"npu\"] or generation_config.compile_config._compile_all_devices\n            ):\n                os.environ[\"TOKENIZERS_PARALLELISM\"] = \"0\"\n                model_forward = self.get_compiled_call(generation_config.compile_config)\n\n        is_prefill = True\n        is_sampling_img = input_ids[:, -1] == self.config.vision_start_token_id\n        generation_image_grid_thw = model_kwargs.pop(\"generation_image_grid_thw\", self.get_default_image_grid_thw())\n        num_img_tokens = self.get_num_image_tokens(generation_image_grid_thw)\n        output_image_embeddings = []\n        while self._has_unfinished_sequences(\n            this_peer_finished, synced_gpus, device=input_ids.device, cur_len=cur_len, max_length=max_length\n        ):\n            # prepare model inputs\n            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)\n\n            # prepare prefilled embeds\n            model_inputs.update(self.prepare_prefilled_image_embeds(len(output_image_embeddings), num_img_tokens, is_sampling_img, **model_kwargs))\n\n            
# parse position_ids from model_kwargs\n            model_inputs.update(self.prepare_image_position_ids(input_ids, generation_image_grid_thw, is_sampling_img, **model_kwargs))\n\n            # prepare variable output controls (note: some models won't accept all output controls)\n            model_inputs.update({\"output_attentions\": output_attentions} if output_attentions else {})\n            model_inputs.update({\"output_hidden_states\": output_hidden_states} if output_hidden_states else {})\n\n            if is_prefill:\n                outputs = self(**model_inputs, return_dict=True)\n                is_prefill = False\n            else:\n                outputs = model_forward(**model_inputs, return_dict=True)\n\n            # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping\n            model_kwargs = self._update_model_kwargs_for_generation(\n                outputs,\n                model_kwargs,\n                is_encoder_decoder=self.config.is_encoder_decoder,\n            )\n            # TODO: support batch image sampling\n            if bool(is_sampling_img) and len(output_image_embeddings) < num_img_tokens:\n                output_image_embeddings.append(outputs.image_embeddings[:, -1, :].unsqueeze(1))\n\n            if synced_gpus and this_peer_finished:\n                continue\n            # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration\n            # (the clone itself is always small)\n            next_token_logits = outputs.logits[:, -1, :].clone().float()\n            next_token_logits = next_token_logits.to(input_ids.device)\n\n            # do not sample <vision_end> token\n            next_token_logits[:, self.config.vision_end_token_id] = -float('inf')\n            # pre-process distribution\n            next_token_scores = logits_processor(input_ids, next_token_logits)\n            # Store scores, attentions and 
hidden_states when required\n            if return_dict_in_generate:\n                if output_scores:\n                    scores += (next_token_scores,)\n                if output_logits:\n                    raw_logits += (next_token_logits,)\n                if output_attentions:\n                    decoder_attentions += (\n                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)\n                    )\n                    if self.config.is_encoder_decoder:\n                        cross_attentions += (outputs.cross_attentions,)\n\n                if output_hidden_states:\n                    decoder_hidden_states += (\n                        (outputs.decoder_hidden_states,)\n                        if self.config.is_encoder_decoder\n                        else (outputs.hidden_states,)\n                    )\n\n            # token selection\n            if do_sample:\n                probs = nn.functional.softmax(next_token_scores, dim=-1)\n                # TODO (joao): this OP throws \"skipping cudagraphs due to ['incompatible ops']\", find solution\n                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)\n                # while not bool(is_sampling_img) and torch.any(next_tokens == self.config.vision_end_token_id):\n                #     probs[:, self.config.vision_end_token_id] = 0\n                #     next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)\n            else:\n                next_tokens = torch.argmax(next_token_scores, dim=-1)\n\n            # finished sentences should have their next token be a padding token\n            if has_eos_stopping_criteria:\n                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)\n\n            #TODO: support batch image sample\n            if num_img_tokens is not None:\n                cur_img_tokens = (input_ids == 
self.config.vision_start_token_id).flip(dims=[1]).float().argmax(dim=1)\n                # check whether is sampling images\n                is_end_img = torch.logical_and(cur_img_tokens == num_img_tokens, is_sampling_img)\n                is_sampling_img = torch.logical_and(is_sampling_img, cur_img_tokens < num_img_tokens)\n                next_tokens[is_sampling_img] = self.config.image_token_id\n                # check whether to end sampling images\n                next_tokens[is_end_img] = self.config.vision_end_token_id\n            else:\n                # check whether to end sampling images\n                is_sampling_img = torch.logical_and(is_sampling_img, (next_tokens != self.config.vision_end_token_id))\n                # replace the next token with the image token if is sampling image\n                next_tokens[is_sampling_img] = self.config.image_token_id\n            # check whether to start sampling images\n            is_sampling_img = torch.logical_or(is_sampling_img, (next_tokens == self.config.vision_start_token_id))\n\n            # update generated ids, model inputs, and length for next step\n            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)\n\n            if streamer is not None:\n                streamer.put(next_tokens.cpu())\n\n            unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)\n            this_peer_finished = unfinished_sequences.max() == 0\n            cur_len += 1\n\n            # This is needed to properly delete outputs.logits which may be very large for first iteration\n            # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration\n            del outputs\n\n        if streamer is not None:\n            streamer.end()\n\n        # output the image embeddings\n        output_image_embeddings = torch.cat(output_image_embeddings, dim=1) if len(output_image_embeddings) > 0 else None\n\n        if 
return_dict_in_generate:\n            return GenerateDecoderOnlyAll2AllOutput(\n                sequences=input_ids,\n                scores=scores,\n                logits=raw_logits,\n                attentions=decoder_attentions,\n                hidden_states=decoder_hidden_states,\n                past_key_values=model_kwargs.get(\"past_key_values\"),\n                output_image_embeddings=output_image_embeddings,\n            )\n        else:\n            return input_ids\n\n\n    def prepare_prefilled_image_embeds(self, cur_image_tokens, num_img_tokens, is_sampling_img, **model_kwargs):\n        if cur_image_tokens == 0 or cur_image_tokens > num_img_tokens or not bool(is_sampling_img):\n            return {}\n        # TODO: support batch image sample\n        image_idx = torch.tensor([cur_image_tokens-1]).to(self.device).long().unsqueeze(0)\n        inputs_embeds = self.image_prefill_embeds(image_idx)\n        return {\"inputs_embeds\": inputs_embeds}\n\n\n    def get_default_image_grid_thw(self,):\n        return torch.tensor([[1, 18, 18]]).to(self.device)\n\n\n    def get_num_image_tokens(self, image_grid_thw):\n        return int(torch.prod(image_grid_thw, dim=1).sum() // 4)\n\n\n    def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):\n        num_img_tokens = model_kwargs.pop(\"generation_image_grid_thw\", None)\n        super()._validate_model_kwargs(model_kwargs)\n        model_kwargs[\"generation_image_grid_thw\"] = num_img_tokens\n\n    def prepare_image_position_ids(self, input_ids, generation_image_grid_thw, is_sampling_img, **model_kwargs):\n        # Overwritten -- prepare position_ids for image tokens\n        cur_img_tokens = int((input_ids == self.config.vision_start_token_id).flip(dims=[1]).float().argmax(dim=1))\n        # TODO: support batch image sample\n        if cur_img_tokens > 0 and bool(is_sampling_img):\n            image_grid_thw = generation_image_grid_thw\n            if model_kwargs.get('image_grid_thw') is not 
None:\n                image_grid_thw = torch.cat([model_kwargs.get('image_grid_thw'), image_grid_thw])\n            remaining_img_tokens = self.get_num_image_tokens(generation_image_grid_thw) - cur_img_tokens\n            padding_ids = input_ids.new_full((1, remaining_img_tokens), fill_value=self.config.image_token_id)\n            padded_ids = torch.cat([input_ids, padding_ids], dim=1)\n            position_ids, _ = self.get_rope_index(padded_ids, image_grid_thw, None, None)\n            if model_kwargs.get(\"use_cache\", True):\n                position_ids = position_ids[:, :, input_ids.shape[1] - 1].unsqueeze(-1)\n            else:\n                position_ids = position_ids[:, :, :input_ids.shape[1]]\n            return {\"position_ids\": position_ids}\n        return {}\n\n    def prepare_inputs_for_generation(\n        self,\n        input_ids,\n        past_key_values=None,\n        attention_mask=None,\n        inputs_embeds=None,\n        cache_position=None,\n        position_ids=None,\n        use_cache=True,\n        pixel_values=None,\n        pixel_values_videos=None,\n        image_grid_thw=None,\n        video_grid_thw=None,\n        second_per_grid_ts=None,\n        image_embeddings=None,\n        **kwargs,\n    ):\n        # Overwritten -- in specific circumstances we don't want to forward image inputs to the model\n\n        model_inputs = super().prepare_inputs_for_generation(\n            input_ids,\n            past_key_values=past_key_values,\n            attention_mask=attention_mask,\n            inputs_embeds=inputs_embeds,\n            cache_position=cache_position,\n            position_ids=position_ids,\n            pixel_values=pixel_values,\n            pixel_values_videos=pixel_values_videos,\n            image_grid_thw=image_grid_thw,\n            video_grid_thw=video_grid_thw,\n            second_per_grid_ts=second_per_grid_ts,\n            use_cache=use_cache,\n            **kwargs,\n        )\n\n        # Qwen2-5-VL 
position_ids are prepared with rope_deltas in forward\n        model_inputs[\"position_ids\"] = None\n\n        if cache_position[0] != 0:\n            model_inputs[\"pixel_values\"] = None\n            model_inputs[\"pixel_values_videos\"] = None\n        return model_inputs\n\n    def _get_image_nums_and_video_nums(\n        self,\n        input_ids: Optional[torch.LongTensor],\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"\n        Get the number of images and videos for each sample to calculate the separation length of the sample tensor.\n        These parameters are not passed through the processor to avoid unpredictable impacts from interface modifications.\n\n        Args:\n            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):\n                Indices of input sequence tokens in the vocabulary.\n\n        Returns:\n            image_nums (`torch.LongTensor` of shape `(batch_size, num_images_sample)`)\n            video_nums (`torch.LongTensor` of shape `(batch_size, num_videos_sample)`)\n        \"\"\"\n        image_token_id = self.config.image_token_id\n        video_token_id = self.config.video_token_id\n        vision_start_token_id = self.config.vision_start_token_id\n\n        vision_start_mask = input_ids == vision_start_token_id\n        vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1)\n        image_mask = input_ids == image_token_id\n        video_mask = input_ids == video_token_id\n        image_nums = torch.sum(vision_first_mask & image_mask, dim=1)\n        video_nums = torch.sum(vision_first_mask & video_mask, dim=1)\n\n        return image_nums, video_nums\n\n    def _expand_inputs_for_generation(\n        self,\n        expand_size: int = 1,\n        is_encoder_decoder: bool = False,\n        input_ids: Optional[torch.LongTensor] = None,\n        **model_kwargs,\n    ) -> Tuple[torch.LongTensor, Dict[str, Any]]:\n        # Overwritten -- Support for expanding tensors without a 
batch size dimension\n        # e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw, second_per_grid_t\n        # pixel_values.shape[0] is sum(seqlen_images for samples)\n        # image_grid_thw.shape[0] is sum(num_images for samples)\n\n        if expand_size == 1:\n            return input_ids, model_kwargs\n\n        visual_keys = [\"pixel_values\", \"image_grid_thw\", \"pixel_values_videos\", \"video_grid_thw\", \"second_per_grid_ts\"]\n\n        def _expand_dict_for_generation_visual(dict_to_expand):\n            image_grid_thw = model_kwargs.get(\"image_grid_thw\", None)\n            video_grid_thw = model_kwargs.get(\"video_grid_thw\", None)\n            image_nums, video_nums = self._get_image_nums_and_video_nums(input_ids)\n\n            def _repeat_interleave_samples(x, lengths, repeat_times):\n                samples = torch.split(x, lengths)\n                repeat_args = [repeat_times] + [1] * (x.dim() - 1)\n                result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0)\n                return result\n\n            for key in dict_to_expand:\n                if key == \"pixel_values\":\n                    # split images into samples\n                    samples = torch.split(image_grid_thw, list(image_nums))\n                    # compute the sequence length of images for each sample\n                    lengths = [torch.prod(sample, dim=1).sum() for sample in samples]\n                    dict_to_expand[key] = _repeat_interleave_samples(\n                        dict_to_expand[key], lengths=lengths, repeat_times=expand_size\n                    )\n                elif key == \"image_grid_thw\":\n                    # get the num of images for each sample\n                    lengths = list(image_nums)\n                    dict_to_expand[key] = _repeat_interleave_samples(\n                        dict_to_expand[key], lengths=lengths, repeat_times=expand_size\n                    )\n                
elif key == \"pixel_values_videos\":\n                    samples = torch.split(video_grid_thw, list(video_nums))\n                    lengths = [torch.prod(sample, dim=1).sum() for sample in samples]\n                    dict_to_expand[key] = _repeat_interleave_samples(\n                        dict_to_expand[key], lengths=lengths, repeat_times=expand_size\n                    )\n                elif key == \"video_grid_thw\":\n                    lengths = list(video_nums)\n                    dict_to_expand[key] = _repeat_interleave_samples(\n                        dict_to_expand[key], lengths=lengths, repeat_times=expand_size\n                    )\n                elif key == \"second_per_grid_ts\":\n                    if not isinstance(dict_to_expand[key], list):\n                        raise TypeError(\n                            f\"Expected value for key '{key}' to be a list, but got {type(dict_to_expand[key])} instead.\"\n                        )\n                    tensor = torch.tensor(dict_to_expand[key])\n                    lengths = list(video_nums)\n                    tensor = _repeat_interleave_samples(tensor, lengths=lengths, repeat_times=expand_size)\n                    dict_to_expand[key] = tensor.tolist()\n            return dict_to_expand\n\n        def _expand_dict_for_generation(dict_to_expand):\n            for key in dict_to_expand:\n                if (\n                    key != \"cache_position\"\n                    and dict_to_expand[key] is not None\n                    and isinstance(dict_to_expand[key], torch.Tensor)\n                    and key not in visual_keys\n                ):\n                    dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)\n            return dict_to_expand\n\n        # input_ids is required for expanding visual inputs\n        # If input_ids is unavailable, visual inputs will not be used; therefore, there is no need to expand visual inputs.\n        if input_ids 
is not None and input_ids.numel() != 0:\n            model_kwargs = _expand_dict_for_generation_visual(model_kwargs)\n\n        if input_ids is not None:\n            input_ids = input_ids.repeat_interleave(expand_size, dim=0)\n\n        model_kwargs = _expand_dict_for_generation(model_kwargs)\n\n        if is_encoder_decoder:\n            if model_kwargs.get(\"encoder_outputs\") is None:\n                raise ValueError(\"If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.\")\n            model_kwargs[\"encoder_outputs\"] = _expand_dict_for_generation(model_kwargs[\"encoder_outputs\"])\n\n        return input_ids, model_kwargs\n\n\n__all__ = [\"Qwen2_5_VLForConditionalGeneration\", \"Qwen2_5_VLModel\", \"Qwen2_5_VLPreTrainedModel\"]\n\n\n\nclass Qwen2_5_VLVideosProcessorKwargs(VideosKwargs, total=False):\n    fps: Union[List[float], float]\n\n\nclass Qwen2_5_VLProcessorKwargs(ProcessingKwargs, total=False):\n    videos_kwargs: Qwen2_5_VLVideosProcessorKwargs\n    _defaults = {\n        \"text_kwargs\": {\n            \"padding\": False,\n        },\n        \"videos_kwargs\": {\"fps\": 2.0},\n    }\n\n\nclass Qwen2_5_VLProcessor(ProcessorMixin):\n    r\"\"\"\n    Constructs a Qwen2.5-VL processor which wraps a Qwen2.5-VL image processor and a Qwen2 tokenizer into a single processor.\n    [`Qwen2_5_VLProcessor`] offers all the functionalities of [`Qwen2VLImageProcessor`] and [`Qwen2TokenizerFast`]. 
See the\n    [`~Qwen2_5_VLProcessor.__call__`] and [`~Qwen2_5_VLProcessor.decode`] for more information.\n    Args:\n        image_processor ([`Qwen2VLImageProcessor`], *optional*):\n            The image processor is a required input.\n        tokenizer ([`Qwen2TokenizerFast`], *optional*):\n            The tokenizer is a required input.\n        chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages\n            in a chat into a tokenizable string.\n    \"\"\"\n\n    attributes = [\"image_processor\", \"tokenizer\"]\n    valid_kwargs = [\"chat_template\"]\n\n    image_processor_class = \"AutoImageProcessor\"\n    tokenizer_class = (\"Qwen2Tokenizer\", \"Qwen2TokenizerFast\")\n\n    def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):\n        self.image_token = \"<|image_pad|>\" if not hasattr(tokenizer, \"image_token\") else tokenizer.image_token\n        self.video_token = \"<|video_pad|>\" if not hasattr(tokenizer, \"video_token\") else tokenizer.video_token\n        super().__init__(image_processor, tokenizer, chat_template=chat_template)\n\n    def __call__(\n        self,\n        images: ImageInput = None,\n        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,\n        videos: VideoInput = None,\n        **kwargs: Unpack[Qwen2_5_VLProcessorKwargs],\n    ) -> BatchFeature:\n        \"\"\"\n        Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`\n        and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode\n        the text. 
To prepare the vision inputs, this method forwards the `vision_infos` and `kwrags` arguments to\n        Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`] if `vision_infos` is not `None`.\n\n        Args:\n            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):\n                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch\n                tensor. Both channels-first and channels-last formats are supported.\n            text (`str`, `List[str]`, `List[List[str]]`):\n                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings\n                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set\n                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).\n            videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):\n                The image or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch\n                tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported.\n            return_tensors (`str` or [`~utils.TensorType`], *optional*):\n                If set, will return tensors of a particular framework. Acceptable values are:\n                - `'tf'`: Return TensorFlow `tf.constant` objects.\n                - `'pt'`: Return PyTorch `torch.Tensor` objects.\n                - `'np'`: Return NumPy `np.ndarray` objects.\n                - `'jax'`: Return JAX `jnp.ndarray` objects.\n\n        Returns:\n            [`BatchFeature`]: A [`BatchFeature`] with the following fields:\n\n            - **input_ids** -- List of token ids to be fed to a model. 
Returned when `text` is not `None`.\n            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when\n              `return_attention_mask=True` or if *\"attention_mask\"* is in `self.model_input_names` and if `text` is not\n              `None`).\n            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.\n            - **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`.\n            - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`.\n            - **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`.\n            - **second_per_grid_ts** -- List of video seconds per time grid. Returned when `videos` is not `None`.\n        \"\"\"\n        output_kwargs = self._merge_kwargs(\n            Qwen2_5_VLProcessorKwargs,\n            tokenizer_init_kwargs=self.tokenizer.init_kwargs,\n            **kwargs,\n        )\n        if images is not None:\n            image_inputs = self.image_processor(images=images, videos=None, **output_kwargs[\"images_kwargs\"])\n            image_grid_thw = image_inputs[\"image_grid_thw\"]\n        else:\n            image_inputs = {}\n            image_grid_thw = None\n\n        if videos is not None:\n            videos_inputs = self.image_processor(images=None, videos=videos, **output_kwargs[\"images_kwargs\"])\n            video_grid_thw = videos_inputs[\"video_grid_thw\"]\n\n            fps = output_kwargs[\"videos_kwargs\"].pop(\"fps\", 2.0)\n            if isinstance(fps, (int, float)):\n                second_per_grid_ts = [self.image_processor.temporal_patch_size / fps] * len(video_grid_thw)\n            elif hasattr(fps, \"__len__\") and len(fps) == len(video_grid_thw):\n                second_per_grid_ts = [self.image_processor.temporal_patch_size / tmp for tmp in fps]\n            else:\n          
      raise ValueError(\n                    f\"The length of fps ({len(fps) if hasattr(fps, '__len__') else fps}) must be equal to the length of video_grid_thw ({len(video_grid_thw)}) or fps should be a single number.\"\n                )\n            videos_inputs.update({\"second_per_grid_ts\": second_per_grid_ts})\n\n        else:\n            videos_inputs = {}\n            video_grid_thw = None\n\n        if not isinstance(text, list):\n            text = [text]\n\n        if image_grid_thw is not None:\n            merge_length = self.image_processor.merge_size**2\n            index = 0\n            for i in range(len(text)):\n                while self.image_token in text[i]:\n                    text[i] = text[i].replace(\n                        self.image_token,\n                        \"<|placeholder|>\" * (image_grid_thw[index].prod() // merge_length),\n                        1,\n                    )\n                    index += 1\n                text[i] = text[i].replace(\"<|placeholder|>\", self.image_token)\n\n        if video_grid_thw is not None:\n            merge_length = self.image_processor.merge_size**2\n            index = 0\n            for i in range(len(text)):\n                while self.video_token in text[i]:\n                    text[i] = text[i].replace(\n                        self.video_token,\n                        \"<|placeholder|>\" * (video_grid_thw[index].prod() // merge_length),\n                        1,\n                    )\n                    index += 1\n                text[i] = text[i].replace(\"<|placeholder|>\", self.video_token)\n\n        text_inputs = self.tokenizer(text, **output_kwargs[\"text_kwargs\"])\n\n        return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs})\n\n    def batch_decode(self, *args, **kwargs):\n        \"\"\"\n        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. 
Please\n        refer to the docstring of this method for more information.\n        \"\"\"\n        return self.tokenizer.batch_decode(*args, **kwargs)\n\n    def batch_decode_all2all(self, *args, **kwargs):\n        \"\"\"\n        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please\n        refer to the docstring of this method for more information.\n        \"\"\"\n        decoded = self.tokenizer.batch_decode(*args, **kwargs)\n        pattern = r'<\\|vision_start\\|>.*?<\\|vision_end\\|>'\n        decoded_with_image_tag = [re.sub(pattern, '<image>', d, flags=re.DOTALL) for d in decoded]\n        decoded_with_image_tag = [re.sub(r'<\\|im_end\\|>', '', d) for d in decoded_with_image_tag]\n        return decoded_with_image_tag\n\n    def decode(self, *args, **kwargs):\n        \"\"\"\n        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to\n        the docstring of this method for more information.\n        \"\"\"\n        return self.tokenizer.decode(*args, **kwargs)\n\n    def post_process_image_text_to_text(\n        self, generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False, **kwargs\n    ):\n        \"\"\"\n        Post-process the output of the model to decode the text.\n\n        Args:\n            generated_outputs (`torch.Tensor` or `np.ndarray`):\n                The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`\n                or `(sequence_length,)`.\n            skip_special_tokens (`bool`, *optional*, defaults to `True`):\n                Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method.\n            Clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):\n                Whether or not to clean up the tokenization spaces. 
Argument passed to the tokenizer's `batch_decode` method.\n            **kwargs:\n                Additional arguments to be passed to the tokenizer's `batch_decode method`.\n\n        Returns:\n            `List[str]`: The decoded text.\n        \"\"\"\n        return self.tokenizer.batch_decode(\n            generated_outputs,\n            skip_special_tokens=skip_special_tokens,\n            clean_up_tokenization_spaces=clean_up_tokenization_spaces,\n            **kwargs,\n        )\n\n    @property\n    def model_input_names(self):\n        tokenizer_input_names = self.tokenizer.model_input_names\n        image_processor_input_names = self.image_processor.model_input_names\n        names_from_processor = list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))\n        return names_from_processor + [\"second_per_grid_ts\"]\n\n\n__all__ = [\"Qwen2_5_VLProcessor\"]\n"
  },
  {
    "path": "diffsynth/models/nexus_gen_projector.py",
    "content": "import math\nimport torch\nimport torch.nn as nn\nfrom typing import Optional, Tuple\n\n\n\ndef rotate_half(x):\n    \"\"\"Rotates half the hidden dims of the input.\"\"\"\n    x1 = x[..., : x.shape[-1] // 2]\n    x2 = x[..., x.shape[-1] // 2 :]\n    return torch.cat((-x2, x1), dim=-1)\n\n\ndef apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):\n    mrope_section = mrope_section * 2\n    cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(\n        unsqueeze_dim\n    )\n    sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(\n        unsqueeze_dim\n    )\n\n    q_embed = (q * cos) + (rotate_half(q) * sin)\n    k_embed = (k * cos) + (rotate_half(k) * sin)\n    return q_embed, k_embed\n\n\nclass Qwen2_5_VLRotaryEmbedding(nn.Module):\n    def __init__(self, config, device=None):\n        super().__init__()\n        # BC: \"rope_type\" was originally \"type\"\n        if hasattr(config, \"rope_scaling\") and config.rope_scaling is not None:\n            self.rope_type = config.rope_scaling.get(\"rope_type\", config.rope_scaling.get(\"type\"))\n        else:\n            self.rope_type = \"default\"\n        self.max_seq_len_cached = config.max_position_embeddings\n        self.original_max_seq_len = config.max_position_embeddings\n\n        self.config = config\n        from transformers.modeling_rope_utils import _compute_default_rope_parameters\n        self.rope_init_fn = _compute_default_rope_parameters\n\n        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)\n        self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n        self.original_inv_freq = self.inv_freq\n\n\n    def _dynamic_frequency_update(self, position_ids, device):\n        \"\"\"\n        dynamic RoPE layers should recompute `inv_freq` in the following situations:\n        1 - growing beyond the cached sequence 
length (allow scaling)\n        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)\n        \"\"\"\n        seq_len = torch.max(position_ids) + 1\n        if seq_len > self.max_seq_len_cached:  # growth\n            inv_freq, self.attention_scaling = self.rope_init_fn(\n                self.config, device, seq_len=seq_len, **self.rope_kwargs\n            )\n            self.register_buffer(\"inv_freq\", inv_freq, persistent=False)  # TODO joao: may break with compilation\n            self.max_seq_len_cached = seq_len\n\n        if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:  # reset\n            self.register_buffer(\"inv_freq\", self.original_inv_freq, persistent=False)\n            self.max_seq_len_cached = self.original_max_seq_len\n\n\n    @torch.no_grad()\n    def forward(self, x, position_ids):\n        if \"dynamic\" in self.rope_type:\n            self._dynamic_frequency_update(position_ids, device=x.device)\n\n        # Core RoPE block. In contrast to other models, Qwen2_5_VL has different position ids for the grids\n        # So we expand the inv_freq to shape (3, ...)\n        inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)\n        position_ids_expanded = position_ids[:, :, None, :].float()  # shape (3, bs, 1, positions)\n        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)\n        device_type = x.device.type\n        device_type = device_type if isinstance(device_type, str) and device_type != \"mps\" else \"cpu\"\n        with torch.autocast(device_type=device_type, enabled=False):\n            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)\n            emb = torch.cat((freqs, freqs), dim=-1)\n            cos = emb.cos()\n            sin = emb.sin()\n\n        # Advanced RoPE types (e.g. 
yarn) apply a post-processing scaling factor, equivalent to scaling attention\n        cos = cos * self.attention_scaling\n        sin = sin * self.attention_scaling\n\n        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)\n\n\ndef repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:\n    \"\"\"\n    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,\n    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)\n    \"\"\"\n    batch, num_key_value_heads, slen, head_dim = hidden_states.shape\n    if n_rep == 1:\n        return hidden_states\n    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)\n    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)\n\n\nclass Qwen2_5_VLAttention(nn.Module):\n    def __init__(self, config, layer_idx: Optional[int] = None):\n        super().__init__()\n        self.config = config\n        self.layer_idx = layer_idx\n\n        self.hidden_size = config.hidden_size\n        self.num_heads = config.num_attention_heads\n        self.head_dim = self.hidden_size // self.num_heads\n        self.num_key_value_heads = config.num_key_value_heads\n        self.num_key_value_groups = self.num_heads // self.num_key_value_heads\n        self.is_causal = True\n        self.attention_dropout = config.attention_dropout\n        self.rope_scaling = config.rope_scaling\n\n        if (self.head_dim * self.num_heads) != self.hidden_size:\n            raise ValueError(\n                f\"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}\"\n                f\" and `num_heads`: {self.num_heads}).\"\n            )\n        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)\n        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)\n        self.v_proj 
= nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)\n        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)\n\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        bsz, q_len, _ = hidden_states.size()\n\n        query_states = self.q_proj(hidden_states)\n        key_states = self.k_proj(hidden_states)\n        value_states = self.v_proj(hidden_states)\n\n        query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)\n        key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)\n        value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)\n\n        cos, sin = position_embeddings\n        query_states, key_states = apply_multimodal_rotary_pos_emb(\n            query_states, key_states, cos, sin, self.rope_scaling[\"mrope_section\"]\n        )\n\n        # repeat k/v heads if n_kv_heads < n_heads\n        key_states = repeat_kv(key_states, self.num_key_value_groups)\n        value_states = repeat_kv(value_states, self.num_key_value_groups)\n\n        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)\n\n        # Fix precision issues in Qwen2-VL float16 inference\n        # Replace inf values with zeros in attention weights to prevent NaN propagation\n        if query_states.dtype == torch.float16:\n            attn_weights = torch.where(torch.isinf(attn_weights), torch.zeros_like(attn_weights), attn_weights)\n\n        # upcast attention to fp32\n        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)\n        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, 
training=self.training)\n        attn_output = torch.matmul(attn_weights, value_states)\n\n        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):\n            raise ValueError(\n                f\"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is\"\n                f\" {attn_output.size()}\"\n            )\n\n        attn_output = attn_output.transpose(1, 2).contiguous()\n        attn_output = attn_output.reshape(bsz, q_len, -1)\n\n        attn_output = self.o_proj(attn_output)\n\n        return attn_output\n\n\nclass Qwen2MLP(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        from transformers.activations import ACT2FN\n        self.config = config\n        self.hidden_size = config.hidden_size\n        self.intermediate_size = config.intermediate_size\n        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)\n        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)\n        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)\n        self.act_fn = ACT2FN[config.hidden_act]\n\n    def forward(self, x):\n        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))\n        return down_proj\n\n\nclass Qwen2RMSNorm(nn.Module):\n    def __init__(self, hidden_size, eps=1e-6):\n        \"\"\"\n        Qwen2RMSNorm is equivalent to T5LayerNorm\n        \"\"\"\n        super().__init__()\n        self.weight = nn.Parameter(torch.ones(hidden_size))\n        self.variance_epsilon = eps\n\n    def forward(self, hidden_states):\n        input_dtype = hidden_states.dtype\n        hidden_states = hidden_states.to(torch.float32)\n        variance = hidden_states.pow(2).mean(-1, keepdim=True)\n        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)\n        return self.weight * hidden_states.to(input_dtype)\n\n    def extra_repr(self):\n        
return f\"{tuple(self.weight.shape)}, eps={self.variance_epsilon}\"\n\n\nclass Qwen2_5_VLDecoderLayer(nn.Module):\n    def __init__(self, config, layer_idx):\n        super().__init__()\n        self.hidden_size = config.hidden_size\n\n        self.self_attn = Qwen2_5_VLAttention(config, layer_idx)\n\n        self.mlp = Qwen2MLP(config)\n        self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n        self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC\n    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:\n\n        residual = hidden_states\n\n        hidden_states = self.input_layernorm(hidden_states)\n\n        # Self Attention\n        hidden_states = self.self_attn(\n            hidden_states=hidden_states,\n            position_embeddings=position_embeddings,\n        )\n        hidden_states = residual + hidden_states\n\n        # Fully Connected\n        residual = hidden_states\n        hidden_states = self.post_attention_layernorm(hidden_states)\n        hidden_states = self.mlp(hidden_states)\n        hidden_states = residual + hidden_states\n\n        return hidden_states\n\n\nclass NexusGenImageEmbeddingMerger(nn.Module):\n    def __init__(self, num_layers=1, out_channel=4096, expand_ratio=4, device='cpu'):\n        super().__init__()\n        from transformers import Qwen2_5_VLConfig\n        from transformers.activations import ACT2FN\n        config = Qwen2_5_VLConfig(**{\n            \"_name_or_path\": \"DiffSynth-Studio/Nexus-GenV2\",\n            \"architectures\": [\n                \"Qwen2_5_VLForConditionalGeneration\"\n            ],\n            \"attention_dropout\": 0.0,\n            \"auto_map\": {\n                \"AutoConfig\": 
\"configuration_qwen2_5_vl.Qwen2_5_VLConfig\",\n                \"AutoModel\": \"modeling_qwen2_5_vl.Qwen2_5_VLModel\",\n                \"AutoModelForCausalLM\": \"modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration\"\n            },\n            \"bos_token_id\": 151643,\n            \"eos_token_id\": 151645,\n            \"hidden_act\": \"silu\",\n            \"hidden_size\": 3584,\n            \"image_token_id\": 151655,\n            \"initializer_range\": 0.02,\n            \"intermediate_size\": 18944,\n            \"max_position_embeddings\": 128000,\n            \"max_window_layers\": 28,\n            \"model_type\": \"qwen2_5_vl\",\n            \"num_attention_heads\": 28,\n            \"num_hidden_layers\": 28,\n            \"num_key_value_heads\": 4,\n            \"pad_token_id\": 151643,\n            \"rms_norm_eps\": 1e-06,\n            \"rope_scaling\": {\n                \"mrope_section\": [\n                16,\n                24,\n                24\n                ],\n                \"rope_type\": \"default\",\n                \"type\": \"default\"\n            },\n            \"rope_theta\": 1000000.0,\n            \"sliding_window\": 32768,\n            \"tie_word_embeddings\": False,\n            \"torch_dtype\": \"bfloat16\",\n            \"transformers_version\": \"4.49.0\",\n            \"use_cache\": False,\n            \"use_sliding_window\": False,\n            \"video_token_id\": 151656,\n            \"vision_config\": {\n                \"hidden_size\": 1280,\n                \"in_chans\": 3,\n                \"model_type\": \"qwen2_5_vl\",\n                \"spatial_patch_size\": 14,\n                \"tokens_per_second\": 2,\n                \"torch_dtype\": \"bfloat16\"\n            },\n            \"vision_end_token_id\": 151653,\n            \"vision_start_token_id\": 151652,\n            \"vision_token_id\": 151654,\n            \"vocab_size\": 152064\n        })\n        self.config = config\n        self.num_layers = 
num_layers\n        self.layers = nn.ModuleList([Qwen2_5_VLDecoderLayer(config, layer_idx) for layer_idx in range(num_layers)])\n        self.projector = nn.Sequential(Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps),\n                                       nn.Linear(config.hidden_size, out_channel * expand_ratio),\n                                       Qwen2RMSNorm(out_channel * expand_ratio, eps=config.rms_norm_eps),\n                                       ACT2FN[config.hidden_act], nn.Linear(out_channel * expand_ratio, out_channel),\n                                       Qwen2RMSNorm(out_channel, eps=config.rms_norm_eps))\n        self.base_grid = torch.tensor([[1, 72, 72]], device=device)\n        self.rotary_emb = Qwen2_5_VLRotaryEmbedding(config=config, device=device)\n\n    def get_position_ids(self, image_grid_thw):\n        \"\"\"\n        Generates position ids for the input embeddings grid.\n        modified from the qwen2_vl mrope.\n        \"\"\"\n        batch_size = image_grid_thw.shape[0]\n        spatial_merge_size = self.config.vision_config.spatial_merge_size\n        t, h, w = (\n            image_grid_thw[0][0],\n            image_grid_thw[0][1],\n            image_grid_thw[0][2],\n        )\n        llm_grid_t, llm_grid_h, llm_grid_w = (\n            t.item(),\n            h.item() // spatial_merge_size,\n            w.item() // spatial_merge_size,\n        )\n        scale_h = self.base_grid[0][1].item() / h.item()\n        scale_w = self.base_grid[0][2].item() / w.item()\n\n        range_tensor = torch.arange(llm_grid_t).view(-1, 1)\n        expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w)\n        time_tensor = expanded_range * self.config.vision_config.tokens_per_second\n        t_index = time_tensor.long().flatten().to(image_grid_thw.device)\n        h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten().to(image_grid_thw.device) * scale_h\n        w_index = 
torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten().to(image_grid_thw.device) * scale_w\n        # 3, B, L\n        position_ids = torch.stack([t_index, h_index, w_index]).unsqueeze(0).repeat(batch_size, 1, 1).permute(1, 0, 2)\n        return position_ids\n\n    def forward(self, embeds, embeds_grid, ref_embeds=None, ref_embeds_grid=None):\n        position_ids = self.get_position_ids(embeds_grid)\n        hidden_states = embeds\n        if ref_embeds is not None:\n            position_ids_ref_embeds = self.get_position_ids(ref_embeds_grid)\n            position_ids = torch.cat((position_ids, position_ids_ref_embeds), dim=-1)\n            hidden_states = torch.cat((embeds, ref_embeds), dim=1)\n\n        position_embeddings = self.rotary_emb(hidden_states, position_ids)\n        for layer in self.layers:\n            hidden_states = layer(hidden_states, position_embeddings)\n\n        hidden_states = self.projector(hidden_states)\n        return hidden_states\n\n    @staticmethod\n    def state_dict_converter():\n        return NexusGenMergerStateDictConverter()\n\n\nclass NexusGenMergerStateDictConverter:\n    def __init__(self):\n        pass\n\n    def from_diffusers(self, state_dict):\n        return state_dict\n    \n    def from_civitai(self, state_dict):\n        merger_state_dict = {key.replace(\"embedding_merger.\", \"\"): value for key, value in state_dict.items() if key.startswith('embedding_merger.')}\n        return merger_state_dict\n\n\nclass NexusGenAdapter(nn.Module):\n    \"\"\"\n    Adapter for Nexus-Gen generation decoder.\n    \"\"\"\n    def __init__(self, input_dim=3584, output_dim=4096):\n        super(NexusGenAdapter, self).__init__()\n        self.adapter = nn.Sequential(nn.Linear(input_dim, output_dim),\n                                     nn.LayerNorm(output_dim), nn.ReLU(),\n                                     nn.Linear(output_dim, output_dim),\n                                     
nn.LayerNorm(output_dim))\n\n    def forward(self, x):\n        return self.adapter(x)\n\n    @staticmethod\n    def state_dict_converter():\n        return NexusGenAdapterStateDictConverter()\n\n\nclass NexusGenAdapterStateDictConverter:\n    def __init__(self):\n        pass\n\n    def from_diffusers(self, state_dict):\n        return state_dict\n    \n    def from_civitai(self, state_dict):\n        adapter_state_dict = {key: value for key, value in state_dict.items() if key.startswith('adapter.')}\n        return adapter_state_dict\n"
  },
  {
    "path": "diffsynth/models/qwen_image_controlnet.py",
    "content": "import torch\nimport torch.nn as nn\nfrom .general_modules import RMSNorm\n\n\nclass BlockWiseControlBlock(torch.nn.Module):\n    # [linear, gelu, linear]\n    def __init__(self, dim: int = 3072):\n        super().__init__()\n        self.x_rms = RMSNorm(dim, eps=1e-6)\n        self.y_rms = RMSNorm(dim, eps=1e-6)\n        self.input_proj = nn.Linear(dim, dim)\n        self.act = nn.GELU()\n        self.output_proj = nn.Linear(dim, dim)\n\n    def forward(self, x, y):\n        x, y = self.x_rms(x), self.y_rms(y)\n        x = self.input_proj(x + y)\n        x = self.act(x)\n        x = self.output_proj(x)\n        return x\n\n    def init_weights(self):\n        # zero initialize output_proj\n        nn.init.zeros_(self.output_proj.weight)\n        nn.init.zeros_(self.output_proj.bias)\n\n\nclass QwenImageBlockWiseControlNet(torch.nn.Module):\n    def __init__(\n        self,\n        num_layers: int = 60,\n        in_dim: int = 64,\n        additional_in_dim: int = 0,\n        dim: int = 3072,\n    ):\n        super().__init__()\n        self.img_in = nn.Linear(in_dim + additional_in_dim, dim)\n        self.controlnet_blocks = nn.ModuleList(\n            [\n                BlockWiseControlBlock(dim)\n                for _ in range(num_layers)\n            ]\n        )\n\n    def init_weight(self):\n        nn.init.zeros_(self.img_in.weight)\n        nn.init.zeros_(self.img_in.bias)\n        for block in self.controlnet_blocks:\n            block.init_weights()\n\n    def process_controlnet_conditioning(self, controlnet_conditioning):\n        return self.img_in(controlnet_conditioning)\n\n    def blockwise_forward(self, img, controlnet_conditioning, block_id):\n        return self.controlnet_blocks[block_id](img, controlnet_conditioning)\n"
  },
  {
    "path": "diffsynth/models/qwen_image_dit.py",
    "content": "import torch, math, functools\nimport torch.nn as nn\nfrom typing import Tuple, Optional, Union, List\nfrom einops import rearrange\nfrom .general_modules import TimestepEmbeddings, RMSNorm, AdaLayerNorm\n\ntry:\n    import flash_attn_interface\n    FLASH_ATTN_3_AVAILABLE = True\nexcept ModuleNotFoundError:\n    FLASH_ATTN_3_AVAILABLE = False\n\n\ndef qwen_image_flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, attention_mask = None, enable_fp8_attention: bool = False):\n    if FLASH_ATTN_3_AVAILABLE and attention_mask is None:\n        if not enable_fp8_attention:\n            q = rearrange(q, \"b n s d -> b s n d\", n=num_heads)\n            k = rearrange(k, \"b n s d -> b s n d\", n=num_heads)\n            v = rearrange(v, \"b n s d -> b s n d\", n=num_heads)\n            x = flash_attn_interface.flash_attn_func(q, k, v)\n            if isinstance(x, tuple):\n                x = x[0]\n            x = rearrange(x, \"b s n d -> b s (n d)\", n=num_heads)\n        else:\n            origin_dtype = q.dtype\n            q_std, k_std, v_std = q.std(), k.std(), v.std()\n            q, k, v = (q / q_std).to(torch.float8_e4m3fn), (k / k_std).to(torch.float8_e4m3fn), (v / v_std).to(torch.float8_e4m3fn)\n            q = rearrange(q, \"b n s d -> b s n d\", n=num_heads)\n            k = rearrange(k, \"b n s d -> b s n d\", n=num_heads)\n            v = rearrange(v, \"b n s d -> b s n d\", n=num_heads)\n            x = flash_attn_interface.flash_attn_func(q, k, v, softmax_scale=q_std * k_std / math.sqrt(q.size(-1)))\n            if isinstance(x, tuple):\n                x = x[0]\n            x = x.to(origin_dtype) * v_std\n            x = rearrange(x, \"b s n d -> b s (n d)\", n=num_heads)\n    else:\n        x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask)\n        x = rearrange(x, \"b n s d -> b s (n d)\", n=num_heads)\n    return x\n\n\nclass ApproximateGELU(nn.Module):\n    def 
__init__(self, dim_in: int, dim_out: int, bias: bool = True):\n        super().__init__()\n        self.proj = nn.Linear(dim_in, dim_out, bias=bias)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x = self.proj(x)\n        return x * torch.sigmoid(1.702 * x)\n\ndef apply_rotary_emb_qwen(\n    x: torch.Tensor,\n    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]]\n):\n    x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))\n    x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)\n    return x_out.type_as(x)\n\n\nclass QwenEmbedRope(nn.Module):\n    def __init__(self, theta: int, axes_dim: list[int], scale_rope=False):\n        super().__init__()\n        self.theta = theta\n        self.axes_dim = axes_dim\n        pos_index = torch.arange(4096)\n        neg_index = torch.arange(4096).flip(0) * -1 - 1\n        self.pos_freqs = torch.cat([\n            self.rope_params(pos_index, self.axes_dim[0], self.theta),\n            self.rope_params(pos_index, self.axes_dim[1], self.theta),\n            self.rope_params(pos_index, self.axes_dim[2], self.theta),\n        ], dim=1)\n        self.neg_freqs = torch.cat([\n            self.rope_params(neg_index, self.axes_dim[0], self.theta),\n            self.rope_params(neg_index, self.axes_dim[1], self.theta),\n            self.rope_params(neg_index, self.axes_dim[2], self.theta),\n        ], dim=1)\n        self.rope_cache = {}\n        self.scale_rope = scale_rope\n        \n    def rope_params(self, index, dim, theta=10000):\n        \"\"\"\n            Args:\n                index: [0, 1, 2, 3] 1D Tensor representing the position index of the token\n        \"\"\"\n        assert dim % 2 == 0\n        freqs = torch.outer(\n            index,\n            1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim))\n        )\n        freqs = torch.polar(torch.ones_like(freqs), freqs)\n        return freqs\n\n\n    def _expand_pos_freqs_if_needed(self, 
video_fhw, txt_seq_lens):\n        if isinstance(video_fhw, list):\n            video_fhw = tuple(max([i[j] for i in video_fhw]) for j in range(3))\n        _, height, width = video_fhw\n        if self.scale_rope:\n            max_vid_index = max(height // 2, width // 2)\n        else:\n            max_vid_index = max(height, width)\n        required_len = max_vid_index + max(txt_seq_lens)\n        cur_max_len = self.pos_freqs.shape[0]\n        if required_len <= cur_max_len:\n            return\n\n        new_max_len = math.ceil(required_len / 512) * 512\n        pos_index = torch.arange(new_max_len)\n        neg_index = torch.arange(new_max_len).flip(0) * -1 - 1\n        self.pos_freqs = torch.cat([\n            self.rope_params(pos_index, self.axes_dim[0], self.theta),\n            self.rope_params(pos_index, self.axes_dim[1], self.theta),\n            self.rope_params(pos_index, self.axes_dim[2], self.theta),\n        ], dim=1)\n        self.neg_freqs = torch.cat([\n            self.rope_params(neg_index, self.axes_dim[0], self.theta),\n            self.rope_params(neg_index, self.axes_dim[1], self.theta),\n            self.rope_params(neg_index, self.axes_dim[2], self.theta),\n        ], dim=1)\n        return\n\n\n    def forward(self, video_fhw, txt_seq_lens, device):\n        self._expand_pos_freqs_if_needed(video_fhw, txt_seq_lens)\n        if self.pos_freqs.device != device:\n            self.pos_freqs = self.pos_freqs.to(device)\n            self.neg_freqs = self.neg_freqs.to(device)\n\n        vid_freqs = []\n        max_vid_index = 0\n        for idx, fhw in enumerate(video_fhw):\n            frame, height, width = fhw\n            rope_key = f\"{idx}_{height}_{width}\"\n\n            if rope_key not in self.rope_cache:\n                seq_lens = frame * height * width\n                freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)\n                freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)\n    
            freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)\n                if self.scale_rope:\n                    freqs_height = torch.cat(\n                        [freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0\n                    )\n                    freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)\n                    freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)\n                    freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)\n\n                else:\n                    freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)\n                    freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)\n\n                freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)\n                self.rope_cache[rope_key] = freqs.clone().contiguous()\n            vid_freqs.append(self.rope_cache[rope_key])\n\n            if self.scale_rope:\n                max_vid_index = max(height // 2, width // 2, max_vid_index)\n            else:\n                max_vid_index = max(height, width, max_vid_index)\n\n        max_len = max(txt_seq_lens)\n        txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]\n        vid_freqs = torch.cat(vid_freqs, dim=0)\n\n        return vid_freqs, txt_freqs\n\n\n    def forward_sampling(self, video_fhw, txt_seq_lens, device):\n        self._expand_pos_freqs_if_needed(video_fhw, txt_seq_lens)\n        if self.pos_freqs.device != device:\n            self.pos_freqs = self.pos_freqs.to(device)\n            self.neg_freqs = self.neg_freqs.to(device)\n\n        vid_freqs = []\n        max_vid_index = 0\n        for idx, fhw in enumerate(video_fhw):\n            frame, height, width = fhw\n            
rope_key = f\"{idx}_{height}_{width}\"\n            if idx > 0 and f\"{0}_{height}_{width}\" not in self.rope_cache:\n                frame_0, height_0, width_0 = video_fhw[0]\n\n                rope_key_0 = f\"0_{height_0}_{width_0}\"\n                spatial_freqs_0 = self.rope_cache[rope_key_0].reshape(frame_0, height_0, width_0, -1)\n                h_indices = torch.linspace(0, height_0 - 1, height).long()\n                w_indices = torch.linspace(0, width_0 - 1, width).long()\n                h_grid, w_grid = torch.meshgrid(h_indices, w_indices, indexing='ij')\n                sampled_rope = spatial_freqs_0[:, h_grid, w_grid, :]\n\n                freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)\n                freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)\n                sampled_rope[:, :, :, :freqs_frame.shape[-1]] = freqs_frame\n\n                seq_lens = frame * height * width\n                self.rope_cache[rope_key] = sampled_rope.reshape(seq_lens, -1).clone()\n            if rope_key not in self.rope_cache:\n                seq_lens = frame * height * width\n                freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)\n                freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)\n                freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)\n                if self.scale_rope:\n                    freqs_height = torch.cat(\n                        [freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0\n                    )\n                    freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)\n                    freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)\n                    freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, 
width, -1)\n\n                else:\n                    freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)\n                    freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)\n\n                freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)\n                self.rope_cache[rope_key] = freqs.clone()\n            vid_freqs.append(self.rope_cache[rope_key].contiguous())\n\n            if self.scale_rope:\n                max_vid_index = max(height // 2, width // 2, max_vid_index)\n            else:\n                max_vid_index = max(height, width, max_vid_index)\n\n        max_len = max(txt_seq_lens)\n        txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]\n        vid_freqs = torch.cat(vid_freqs, dim=0)\n\n        return vid_freqs, txt_freqs\n\n\nclass QwenEmbedLayer3DRope(nn.Module):\n    def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):\n        super().__init__()\n        self.theta = theta\n        self.axes_dim = axes_dim\n        pos_index = torch.arange(4096)\n        neg_index = torch.arange(4096).flip(0) * -1 - 1\n        self.pos_freqs = torch.cat(\n            [\n                self.rope_params(pos_index, self.axes_dim[0], self.theta),\n                self.rope_params(pos_index, self.axes_dim[1], self.theta),\n                self.rope_params(pos_index, self.axes_dim[2], self.theta),\n            ],\n            dim=1,\n        )\n        self.neg_freqs = torch.cat(\n            [\n                self.rope_params(neg_index, self.axes_dim[0], self.theta),\n                self.rope_params(neg_index, self.axes_dim[1], self.theta),\n                self.rope_params(neg_index, self.axes_dim[2], self.theta),\n            ],\n            dim=1,\n        )\n\n        self.scale_rope = scale_rope\n\n    def rope_params(self, index, dim, theta=10000):\n        \"\"\"\n        Args:\n      
      index: [0, 1, 2, 3] 1D Tensor representing the position index of the token\n        \"\"\"\n        assert dim % 2 == 0\n        freqs = torch.outer(index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim)))\n        freqs = torch.polar(torch.ones_like(freqs), freqs)\n        return freqs\n\n    def forward(self, video_fhw, txt_seq_lens, device):\n        \"\"\"\n        Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args:\n        txt_length: [bs] a list of 1 integers representing the length of the text\n        \"\"\"\n        if self.pos_freqs.device != device:\n            self.pos_freqs = self.pos_freqs.to(device)\n            self.neg_freqs = self.neg_freqs.to(device)\n\n        video_fhw = [video_fhw]\n        if isinstance(video_fhw, list):\n            video_fhw = video_fhw[0]\n        if not isinstance(video_fhw, list):\n            video_fhw = [video_fhw]\n\n        vid_freqs = []\n        max_vid_index = 0\n        layer_num = len(video_fhw) - 1\n        for idx, fhw in enumerate(video_fhw):\n            frame, height, width = fhw\n            if idx != layer_num:\n                video_freq = self._compute_video_freqs(frame, height, width, idx)\n            else:\n                ### For the condition image, we set the layer index to -1\n                video_freq = self._compute_condition_freqs(frame, height, width)\n            video_freq = video_freq.to(device)\n            vid_freqs.append(video_freq)\n\n            if self.scale_rope:\n                max_vid_index = max(height // 2, width // 2, max_vid_index)\n            else:\n                max_vid_index = max(height, width, max_vid_index)\n\n        max_vid_index = max(max_vid_index, layer_num)\n        max_len = max(txt_seq_lens)\n        txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]\n        vid_freqs = torch.cat(vid_freqs, dim=0)\n\n        return vid_freqs, txt_freqs\n\n    
@functools.lru_cache(maxsize=None)\n    def _compute_video_freqs(self, frame, height, width, idx=0):\n        seq_lens = frame * height * width\n        freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)\n        freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)\n\n        freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)\n        if self.scale_rope:\n            freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)\n            freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)\n            freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)\n            freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)\n        else:\n            freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)\n            freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)\n\n        freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)\n        return freqs.clone().contiguous()\n\n    @functools.lru_cache(maxsize=None)\n    def _compute_condition_freqs(self, frame, height, width):\n        seq_lens = frame * height * width\n        freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)\n        freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)\n\n        freqs_frame = freqs_neg[0][-1:].view(frame, 1, 1, -1).expand(frame, height, width, -1)\n        if self.scale_rope:\n            freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)\n            freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)\n            freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], 
dim=0)\n            freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)\n        else:\n            freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)\n            freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)\n\n        freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)\n        return freqs.clone().contiguous()\n\n\nclass QwenFeedForward(nn.Module):\n    def __init__(\n        self,\n        dim: int,\n        dim_out: Optional[int] = None,\n        dropout: float = 0.0,\n    ):\n        super().__init__()\n        inner_dim = int(dim * 4)\n        self.net = nn.ModuleList([])\n        self.net.append(ApproximateGELU(dim, inner_dim))\n        self.net.append(nn.Dropout(dropout))\n        self.net.append(nn.Linear(inner_dim, dim_out))\n\n    def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:\n        for module in self.net:\n            hidden_states = module(hidden_states)\n        return hidden_states\n\nclass QwenDoubleStreamAttention(nn.Module):\n    def __init__(\n        self,\n        dim_a,\n        dim_b,\n        num_heads,\n        head_dim,\n    ):\n        super().__init__()\n        self.num_heads = num_heads\n        self.head_dim = head_dim\n\n        self.to_q = nn.Linear(dim_a, dim_a)\n        self.to_k = nn.Linear(dim_a, dim_a)\n        self.to_v = nn.Linear(dim_a, dim_a)\n        self.norm_q = RMSNorm(head_dim, eps=1e-6)\n        self.norm_k = RMSNorm(head_dim, eps=1e-6)\n\n        self.add_q_proj = nn.Linear(dim_b, dim_b)\n        self.add_k_proj = nn.Linear(dim_b, dim_b)\n        self.add_v_proj = nn.Linear(dim_b, dim_b)\n        self.norm_added_q = RMSNorm(head_dim, eps=1e-6)\n        self.norm_added_k = RMSNorm(head_dim, eps=1e-6)\n\n        self.to_out = torch.nn.Sequential(nn.Linear(dim_a, dim_a))\n        self.to_add_out = nn.Linear(dim_b, dim_b)\n\n    def 
forward(\n        self,\n        image: torch.FloatTensor,\n        text: torch.FloatTensor,\n        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,\n        attention_mask: Optional[torch.FloatTensor] = None,\n        enable_fp8_attention: bool = False,\n    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:\n        img_q, img_k, img_v = self.to_q(image), self.to_k(image), self.to_v(image)\n        txt_q, txt_k, txt_v = self.add_q_proj(text), self.add_k_proj(text), self.add_v_proj(text)\n        seq_txt = txt_q.shape[1]\n\n        img_q = rearrange(img_q, 'b s (h d) -> b h s d', h=self.num_heads)\n        img_k = rearrange(img_k, 'b s (h d) -> b h s d', h=self.num_heads)\n        img_v = rearrange(img_v, 'b s (h d) -> b h s d', h=self.num_heads)\n\n        txt_q = rearrange(txt_q, 'b s (h d) -> b h s d', h=self.num_heads)\n        txt_k = rearrange(txt_k, 'b s (h d) -> b h s d', h=self.num_heads)\n        txt_v = rearrange(txt_v, 'b s (h d) -> b h s d', h=self.num_heads)\n\n        img_q, img_k = self.norm_q(img_q), self.norm_k(img_k)\n        txt_q, txt_k = self.norm_added_q(txt_q), self.norm_added_k(txt_k)\n        \n        if image_rotary_emb is not None:\n            img_freqs, txt_freqs = image_rotary_emb\n            img_q = apply_rotary_emb_qwen(img_q, img_freqs)\n            img_k = apply_rotary_emb_qwen(img_k, img_freqs)\n            txt_q = apply_rotary_emb_qwen(txt_q, txt_freqs)\n            txt_k = apply_rotary_emb_qwen(txt_k, txt_freqs)\n\n        joint_q = torch.cat([txt_q, img_q], dim=2)\n        joint_k = torch.cat([txt_k, img_k], dim=2)\n        joint_v = torch.cat([txt_v, img_v], dim=2)\n\n        joint_attn_out = qwen_image_flash_attention(joint_q, joint_k, joint_v, num_heads=joint_q.shape[1], attention_mask=attention_mask, enable_fp8_attention=enable_fp8_attention).to(joint_q.dtype)\n\n        txt_attn_output = joint_attn_out[:, :seq_txt, :]\n        img_attn_output = joint_attn_out[:, seq_txt:, :]\n\n        
img_attn_output = self.to_out(img_attn_output)\n        txt_attn_output = self.to_add_out(txt_attn_output)\n\n        return img_attn_output, txt_attn_output\n\n\nclass QwenImageTransformerBlock(nn.Module):\n    def __init__(\n        self, \n        dim: int, \n        num_attention_heads: int, \n        attention_head_dim: int, \n        eps: float = 1e-6,\n    ):    \n        super().__init__()\n        \n        self.dim = dim\n        self.num_attention_heads = num_attention_heads\n        self.attention_head_dim = attention_head_dim\n\n        self.img_mod = nn.Sequential(\n            nn.SiLU(),\n            nn.Linear(dim, 6 * dim), \n        )\n        self.img_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)\n        self.attn = QwenDoubleStreamAttention(\n            dim_a=dim,\n            dim_b=dim,\n            num_heads=num_attention_heads,\n            head_dim=attention_head_dim,\n        )\n        self.img_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)\n        self.img_mlp = QwenFeedForward(dim=dim, dim_out=dim)\n\n        self.txt_mod = nn.Sequential(\n            nn.SiLU(),\n            nn.Linear(dim, 6 * dim, bias=True), \n        )\n        self.txt_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)\n        self.txt_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)\n        self.txt_mlp = QwenFeedForward(dim=dim, dim_out=dim)\n    \n    def _modulate(self, x, mod_params, index=None):\n        shift, scale, gate = mod_params.chunk(3, dim=-1)\n        if index is not None:\n            # Assuming mod_params batch dim is 2*actual_batch (chunked into 2 parts)\n            # So shift, scale, gate have shape [2*actual_batch, d]\n            actual_batch = shift.size(0) // 2\n            shift_0, shift_1 = shift[:actual_batch], shift[actual_batch:]  # each: [actual_batch, d]\n            scale_0, scale_1 = scale[:actual_batch], scale[actual_batch:]\n            gate_0, gate_1 = gate[:actual_batch], 
gate[actual_batch:]\n\n            # index: [b, l] where b is actual batch size\n            # Expand to [b, l, 1] to match feature dimension\n            index_expanded = index.unsqueeze(-1)  # [b, l, 1]\n\n            # Expand chunks to [b, 1, d] then broadcast to [b, l, d]\n            shift_0_exp = shift_0.unsqueeze(1)  # [b, 1, d]\n            shift_1_exp = shift_1.unsqueeze(1)  # [b, 1, d]\n            scale_0_exp = scale_0.unsqueeze(1)\n            scale_1_exp = scale_1.unsqueeze(1)\n            gate_0_exp = gate_0.unsqueeze(1)\n            gate_1_exp = gate_1.unsqueeze(1)\n\n            # Use torch.where to select based on index\n            shift_result = torch.where(index_expanded == 0, shift_0_exp, shift_1_exp)\n            scale_result = torch.where(index_expanded == 0, scale_0_exp, scale_1_exp)\n            gate_result = torch.where(index_expanded == 0, gate_0_exp, gate_1_exp)\n        else:\n            shift_result = shift.unsqueeze(1)\n            scale_result = scale.unsqueeze(1)\n            gate_result = gate.unsqueeze(1)\n\n        return x * (1 + scale_result) + shift_result, gate_result\n\n    def forward(\n        self,\n        image: torch.Tensor,  \n        text: torch.Tensor,\n        temb: torch.Tensor, \n        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        enable_fp8_attention = False,\n        modulate_index: Optional[List[int]] = None,\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n\n        img_mod_attn, img_mod_mlp = self.img_mod(temb).chunk(2, dim=-1)  # [B, 3*dim] each\n        if modulate_index is not None:\n            temb = torch.chunk(temb, 2, dim=0)[0]\n        txt_mod_attn, txt_mod_mlp = self.txt_mod(temb).chunk(2, dim=-1)  # [B, 3*dim] each\n\n        img_normed = self.img_norm1(image)\n        img_modulated, img_gate = self._modulate(img_normed, img_mod_attn, index=modulate_index)\n\n        txt_normed = self.txt_norm1(text)\n        
txt_modulated, txt_gate = self._modulate(txt_normed, txt_mod_attn)\n\n        img_attn_out, txt_attn_out = self.attn(\n            image=img_modulated,\n            text=txt_modulated,\n            image_rotary_emb=image_rotary_emb,\n            attention_mask=attention_mask,\n            enable_fp8_attention=enable_fp8_attention,\n        )\n        \n        image = image + img_gate * img_attn_out\n        text = text + txt_gate * txt_attn_out\n\n        img_normed_2 = self.img_norm2(image)\n        img_modulated_2, img_gate_2 = self._modulate(img_normed_2, img_mod_mlp, index=modulate_index)\n\n        txt_normed_2 = self.txt_norm2(text)\n        txt_modulated_2, txt_gate_2 = self._modulate(txt_normed_2, txt_mod_mlp)\n\n        img_mlp_out = self.img_mlp(img_modulated_2)\n        txt_mlp_out = self.txt_mlp(txt_modulated_2)\n\n        image = image + img_gate_2 * img_mlp_out\n        text = text + txt_gate_2 * txt_mlp_out\n\n        return text, image\n\n\nclass QwenImageDiT(torch.nn.Module):\n    def __init__(\n        self,\n        num_layers: int = 60,\n        use_layer3d_rope: bool = False,\n        use_additional_t_cond: bool = False,\n    ):\n        super().__init__()\n\n        if not use_layer3d_rope:\n            self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=[16,56,56], scale_rope=True)\n        else:\n            self.pos_embed = QwenEmbedLayer3DRope(theta=10000, axes_dim=[16,56,56], scale_rope=True)\n\n        self.time_text_embed = TimestepEmbeddings(256, 3072, diffusers_compatible_format=True, scale=1000, align_dtype_to_timestep=False, use_additional_t_cond=use_additional_t_cond)\n        self.txt_norm = RMSNorm(3584, eps=1e-6)\n\n        self.img_in = nn.Linear(64, 3072)\n        self.txt_in = nn.Linear(3584, 3072)\n\n        self.transformer_blocks = nn.ModuleList(\n            [\n                QwenImageTransformerBlock(\n                    dim=3072,\n                    num_attention_heads=24,\n                    
attention_head_dim=128,\n                )\n                for _ in range(num_layers)\n            ]\n        )\n        self.norm_out = AdaLayerNorm(3072, single=True)\n        self.proj_out = nn.Linear(3072, 64)\n\n\n    def process_entity_masks(self, latents, prompt_emb, prompt_emb_mask, entity_prompt_emb, entity_prompt_emb_mask, entity_masks, height, width, image, img_shapes):\n        # prompt_emb\n        all_prompt_emb = entity_prompt_emb + [prompt_emb]\n        all_prompt_emb = [self.txt_in(self.txt_norm(local_prompt_emb)) for local_prompt_emb in all_prompt_emb]\n        all_prompt_emb = torch.cat(all_prompt_emb, dim=1)\n\n        # image_rotary_emb\n        txt_seq_lens = prompt_emb_mask.sum(dim=1).tolist()\n        image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=latents.device)\n        entity_seq_lens = [emb_mask.sum(dim=1).tolist() for emb_mask in entity_prompt_emb_mask]\n        entity_rotary_emb = [self.pos_embed(img_shapes, entity_seq_len, device=latents.device)[1] for entity_seq_len in entity_seq_lens]\n        txt_rotary_emb = torch.cat(entity_rotary_emb + [image_rotary_emb[1]], dim=0)\n        image_rotary_emb = (image_rotary_emb[0], txt_rotary_emb)\n\n        # attention_mask\n        repeat_dim = latents.shape[1]\n        max_masks = entity_masks.shape[1]\n        entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1)\n        entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)]\n        global_mask = torch.ones_like(entity_masks[0]).to(device=latents.device, dtype=latents.dtype)\n        entity_masks = entity_masks + [global_mask]\n\n        N = len(entity_masks)\n        batch_size = entity_masks[0].shape[0]\n        seq_lens = [mask_.sum(dim=1).item() for mask_ in entity_prompt_emb_mask] + [prompt_emb_mask.sum(dim=1).item()]\n        total_seq_len = sum(seq_lens) + image.shape[1]\n        patched_masks = []\n        for i in range(N):\n            patched_mask = rearrange(entity_masks[i], \"B 
C (H P) (W Q) -> B (H W) (C P Q)\", H=height//16, W=width//16, P=2, Q=2)\n            patched_masks.append(patched_mask)\n        attention_mask = torch.ones((batch_size, total_seq_len, total_seq_len), dtype=torch.bool).to(device=entity_masks[0].device)\n\n        # prompt-image attention mask\n        image_start = sum(seq_lens)\n        image_end = total_seq_len\n        cumsum = [0]\n        single_image_seq = image_end - image_start\n        for length in seq_lens:\n            cumsum.append(cumsum[-1] + length)\n        for i in range(N):\n            prompt_start = cumsum[i]\n            prompt_end = cumsum[i+1]\n            image_mask = torch.sum(patched_masks[i], dim=-1) > 0\n            image_mask = image_mask.unsqueeze(1).repeat(1, seq_lens[i], 1)\n            # repeat image mask to match the single image sequence length\n            repeat_time = single_image_seq // image_mask.shape[-1]\n            image_mask = image_mask.repeat(1, 1, repeat_time)\n            # prompt update with image\n            attention_mask[:, prompt_start:prompt_end, image_start:image_end] = image_mask\n            # image update with prompt\n            attention_mask[:, image_start:image_end, prompt_start:prompt_end] = image_mask.transpose(1, 2)\n        # prompt-prompt attention mask, let the prompt tokens not attend to each other\n        for i in range(N):\n            for j in range(N):\n                if i == j:\n                    continue\n                start_i, end_i = cumsum[i], cumsum[i+1]\n                start_j, end_j = cumsum[j], cumsum[j+1]\n                attention_mask[:, start_i:end_i, start_j:end_j] = False\n\n        attention_mask = attention_mask.float()\n        attention_mask[attention_mask == 0] = float('-inf')\n        attention_mask[attention_mask == 1] = 0\n        attention_mask = attention_mask.to(device=latents.device, dtype=latents.dtype).unsqueeze(1)\n\n        return all_prompt_emb, image_rotary_emb, attention_mask\n\n\n    def forward(\n 
       self,\n        latents=None,\n        timestep=None,\n        prompt_emb=None,\n        prompt_emb_mask=None,\n        height=None,\n        width=None,\n    ):\n        img_shapes = [(latents.shape[0], latents.shape[2]//2, latents.shape[3]//2)]\n        txt_seq_lens = prompt_emb_mask.sum(dim=1).tolist()\n        \n        image = rearrange(latents, \"B C (H P) (W Q) -> B (H W) (C P Q)\", H=height//16, W=width//16, P=2, Q=2)\n        image = self.img_in(image)\n        text = self.txt_in(self.txt_norm(prompt_emb))\n\n        conditioning = self.time_text_embed(timestep, image.dtype)\n\n        image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=latents.device)\n\n        for block in self.transformer_blocks:\n            text, image = block(\n                image=image,\n                text=text,\n                temb=conditioning,\n                image_rotary_emb=image_rotary_emb,\n            )\n        \n        image = self.norm_out(image, conditioning)\n        image = self.proj_out(image)\n        \n        latents = rearrange(image, \"B (H W) (C P Q) -> B C (H P) (W Q)\", H=height//16, W=width//16, P=2, Q=2)\n        return image\n"
  },
  {
    "path": "diffsynth/models/qwen_image_image2lora.py",
    "content": "import torch\n\n\nclass CompressedMLP(torch.nn.Module):\n    def __init__(self, in_dim, mid_dim, out_dim, bias=False):\n        super().__init__()\n        self.proj_in = torch.nn.Linear(in_dim, mid_dim, bias=bias)\n        self.proj_out = torch.nn.Linear(mid_dim, out_dim, bias=bias)\n        \n    def forward(self, x, residual=None):\n        x = self.proj_in(x)\n        if residual is not None: x = x + residual\n        x = self.proj_out(x)\n        return x\n\n\nclass ImageEmbeddingToLoraMatrix(torch.nn.Module):\n    def __init__(self, in_dim, compress_dim, lora_a_dim, lora_b_dim, rank):\n        super().__init__()\n        self.proj_a = CompressedMLP(in_dim, compress_dim, lora_a_dim * rank)\n        self.proj_b = CompressedMLP(in_dim, compress_dim, lora_b_dim * rank)\n        self.lora_a_dim = lora_a_dim\n        self.lora_b_dim = lora_b_dim\n        self.rank = rank\n        \n    def forward(self, x, residual=None):\n        lora_a = self.proj_a(x, residual).view(self.rank, self.lora_a_dim)\n        lora_b = self.proj_b(x, residual).view(self.lora_b_dim, self.rank)\n        return lora_a, lora_b\n\n\nclass SequencialMLP(torch.nn.Module):\n    def __init__(self, length, in_dim, mid_dim, out_dim, bias=False):\n        super().__init__()\n        self.proj_in = torch.nn.Linear(in_dim, mid_dim, bias=bias)\n        self.proj_out = torch.nn.Linear(length * mid_dim, out_dim, bias=bias)\n        self.length = length\n        self.in_dim = in_dim\n        self.mid_dim = mid_dim\n        \n    def forward(self, x):\n        x = x.view(self.length, self.in_dim)\n        x = self.proj_in(x)\n        x = x.view(1, self.length * self.mid_dim)\n        x = self.proj_out(x)\n        return x\n\n\nclass LoRATrainerBlock(torch.nn.Module):\n    def __init__(self, lora_patterns, in_dim=1536+4096, compress_dim=128, rank=4, block_id=0, use_residual=True, residual_length=64+7, residual_dim=3584, residual_mid_dim=1024):\n        super().__init__()\n        
self.lora_patterns = lora_patterns\n        self.block_id = block_id\n        self.layers = []\n        for name, lora_a_dim, lora_b_dim in self.lora_patterns:\n            self.layers.append(ImageEmbeddingToLoraMatrix(in_dim, compress_dim, lora_a_dim, lora_b_dim, rank))\n        self.layers = torch.nn.ModuleList(self.layers)\n        if use_residual:\n            self.proj_residual = SequencialMLP(residual_length, residual_dim, residual_mid_dim, compress_dim)\n        else:\n            self.proj_residual = None\n    \n    def forward(self, x, residual=None):\n        lora = {}\n        if self.proj_residual is not None: residual = self.proj_residual(residual)\n        for lora_pattern, layer in zip(self.lora_patterns, self.layers):\n            name = lora_pattern[0]\n            lora_a, lora_b = layer(x, residual=residual)\n            lora[f\"transformer_blocks.{self.block_id}.{name}.lora_A.default.weight\"] = lora_a\n            lora[f\"transformer_blocks.{self.block_id}.{name}.lora_B.default.weight\"] = lora_b\n        return lora\n    \n\nclass QwenImageImage2LoRAModel(torch.nn.Module):\n    def __init__(self, num_blocks=60, use_residual=True, compress_dim=128, rank=4, residual_length=64+7, residual_mid_dim=1024):\n        super().__init__()\n        self.lora_patterns = [\n            [\n                (\"attn.to_q\", 3072, 3072),\n                (\"attn.to_k\", 3072, 3072),\n                (\"attn.to_v\", 3072, 3072),\n                (\"attn.to_out.0\", 3072, 3072),\n            ],\n            [\n                (\"img_mlp.net.2\", 3072*4, 3072),\n                (\"img_mod.1\", 3072, 3072*6),\n            ],\n            [\n                (\"attn.add_q_proj\", 3072, 3072),\n                (\"attn.add_k_proj\", 3072, 3072),\n                (\"attn.add_v_proj\", 3072, 3072),\n                (\"attn.to_add_out\", 3072, 3072),\n            ],\n            [\n                (\"txt_mlp.net.2\", 3072*4, 3072),\n                (\"txt_mod.1\", 3072, 
3072*6),\n            ],\n        ]\n        self.num_blocks = num_blocks\n        self.blocks = []\n        for lora_patterns in self.lora_patterns:\n            for block_id in range(self.num_blocks):\n                self.blocks.append(LoRATrainerBlock(lora_patterns, block_id=block_id, use_residual=use_residual, compress_dim=compress_dim, rank=rank, residual_length=residual_length, residual_mid_dim=residual_mid_dim))\n        self.blocks = torch.nn.ModuleList(self.blocks)\n        self.residual_scale = 0.05\n        self.use_residual = use_residual\n        \n    def forward(self, x, residual=None):\n        if residual is not None:\n            if self.use_residual:\n                residual = residual * self.residual_scale\n            else:\n                residual = None\n        lora = {}\n        for block in self.blocks:\n            lora.update(block(x, residual))\n        return lora\n    \n    def initialize_weights(self):\n        state_dict = self.state_dict()\n        for name in state_dict:\n            if \".proj_a.\" in name:\n                state_dict[name] = state_dict[name] * 0.3\n            elif \".proj_b.proj_out.\" in name:\n                state_dict[name] = state_dict[name] * 0\n            elif \".proj_residual.proj_out.\" in name:\n                state_dict[name] = state_dict[name] * 0.3\n        self.load_state_dict(state_dict)\n"
  },
  {
    "path": "diffsynth/models/qwen_image_text_encoder.py",
    "content": "import torch\nfrom typing import Optional, Union\n\n\nclass QwenImageTextEncoder(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        from transformers import Qwen2_5_VLConfig, Qwen2_5_VLModel\n        config = Qwen2_5_VLConfig(**{\n            \"architectures\": [\n                \"Qwen2_5_VLForConditionalGeneration\"\n            ],\n            \"attention_dropout\": 0.0,\n            \"bos_token_id\": 151643,\n            \"eos_token_id\": 151645,\n            \"hidden_act\": \"silu\",\n            \"hidden_size\": 3584,\n            \"image_token_id\": 151655,\n            \"initializer_range\": 0.02,\n            \"intermediate_size\": 18944,\n            \"max_position_embeddings\": 128000,\n            \"max_window_layers\": 28,\n            \"model_type\": \"qwen2_5_vl\",\n            \"num_attention_heads\": 28,\n            \"num_hidden_layers\": 28,\n            \"num_key_value_heads\": 4,\n            \"rms_norm_eps\": 1e-06,\n            \"rope_scaling\": {\n                \"mrope_section\": [\n                    16,\n                    24,\n                    24\n                ],\n                \"rope_type\": \"default\",\n                \"type\": \"default\"\n            },\n            \"rope_theta\": 1000000.0,\n            \"sliding_window\": 32768,\n            \"text_config\": {\n                \"architectures\": [\n                    \"Qwen2_5_VLForConditionalGeneration\"\n                ],\n                \"attention_dropout\": 0.0,\n                \"bos_token_id\": 151643,\n                \"eos_token_id\": 151645,\n                \"hidden_act\": \"silu\",\n                \"hidden_size\": 3584,\n                \"image_token_id\": None,\n                \"initializer_range\": 0.02,\n                \"intermediate_size\": 18944,\n                \"layer_types\": [\n                \"full_attention\",\n                \"full_attention\",\n                \"full_attention\",\n          
      \"full_attention\",\n                \"full_attention\",\n                \"full_attention\",\n                \"full_attention\",\n                \"full_attention\",\n                \"full_attention\",\n                \"full_attention\",\n                \"full_attention\",\n                \"full_attention\",\n                \"full_attention\",\n                \"full_attention\",\n                \"full_attention\",\n                \"full_attention\",\n                \"full_attention\",\n                \"full_attention\",\n                \"full_attention\",\n                \"full_attention\",\n                \"full_attention\",\n                \"full_attention\",\n                \"full_attention\",\n                \"full_attention\",\n                \"full_attention\",\n                \"full_attention\",\n                \"full_attention\",\n                \"full_attention\"\n                ],\n                \"max_position_embeddings\": 128000,\n                \"max_window_layers\": 28,\n                \"model_type\": \"qwen2_5_vl_text\",\n                \"num_attention_heads\": 28,\n                \"num_hidden_layers\": 28,\n                \"num_key_value_heads\": 4,\n                \"rms_norm_eps\": 1e-06,\n                \"rope_scaling\": {\n                \"mrope_section\": [\n                    16,\n                    24,\n                    24\n                ],\n                \"rope_type\": \"default\",\n                \"type\": \"default\"\n                },\n                \"rope_theta\": 1000000.0,\n                \"sliding_window\": None,\n                \"torch_dtype\": \"float32\",\n                \"use_cache\": True,\n                \"use_sliding_window\": False,\n                \"video_token_id\": None,\n                \"vision_end_token_id\": 151653,\n                \"vision_start_token_id\": 151652,\n                \"vision_token_id\": 151654,\n                \"vocab_size\": 152064\n            
},\n            \"tie_word_embeddings\": False,\n            \"torch_dtype\": \"float32\",\n            \"transformers_version\": \"4.54.0\",\n            \"use_cache\": True,\n            \"use_sliding_window\": False,\n            \"video_token_id\": 151656,\n            \"vision_config\": {\n                \"depth\": 32,\n                \"fullatt_block_indexes\": [\n                    7,\n                    15,\n                    23,\n                    31\n                ],\n                \"hidden_act\": \"silu\",\n                \"hidden_size\": 1280,\n                \"in_channels\": 3,\n                \"in_chans\": 3,\n                \"initializer_range\": 0.02,\n                \"intermediate_size\": 3420,\n                \"model_type\": \"qwen2_5_vl\",\n                \"num_heads\": 16,\n                \"out_hidden_size\": 3584,\n                \"patch_size\": 14,\n                \"spatial_merge_size\": 2,\n                \"spatial_patch_size\": 14,\n                \"temporal_patch_size\": 2,\n                \"tokens_per_second\": 2,\n                \"torch_dtype\": \"float32\",\n                \"window_size\": 112\n            },\n            \"vision_end_token_id\": 151653,\n            \"vision_start_token_id\": 151652,\n            \"vision_token_id\": 151654,\n            \"vocab_size\": 152064\n        })\n        self.model = Qwen2_5_VLModel(config)\n        self.lm_head = torch.nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)\n        self.config = config\n        \n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n  
      output_hidden_states: Optional[bool] = None,\n        pixel_values: Optional[torch.Tensor] = None,\n        pixel_values_videos: Optional[torch.FloatTensor] = None,\n        image_grid_thw: Optional[torch.LongTensor] = None,\n        video_grid_thw: Optional[torch.LongTensor] = None,\n        rope_deltas: Optional[torch.LongTensor] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n        second_per_grid_ts: Optional[torch.Tensor] = None,\n        logits_to_keep: Union[int, torch.Tensor] = 0,\n        **kwargs,\n    ):\n        output_attentions = False\n        output_hidden_states = True\n\n        outputs = self.model(\n            input_ids=input_ids,\n            pixel_values=pixel_values,\n            pixel_values_videos=pixel_values_videos,\n            image_grid_thw=image_grid_thw,\n            video_grid_thw=video_grid_thw,\n            second_per_grid_ts=second_per_grid_ts,\n            position_ids=position_ids,\n            attention_mask=attention_mask,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=True,\n            cache_position=cache_position,\n            **kwargs,\n        )\n        return outputs.hidden_states\n"
  },
  {
    "path": "diffsynth/models/qwen_image_vae.py",
    "content": "import torch\nfrom typing import List, Optional, Tuple, Union\nfrom torch import nn\n\n\nCACHE_T = 2\n\nclass QwenImageCausalConv3d(torch.nn.Conv3d):\n    r\"\"\"\n    A custom 3D causal convolution layer with feature caching support.\n\n    This layer extends the standard Conv3D layer by ensuring causality in the time dimension and handling feature\n    caching for efficient inference.\n\n    Args:\n        in_channels (int): Number of channels in the input image\n        out_channels (int): Number of channels produced by the convolution\n        kernel_size (int or tuple): Size of the convolving kernel\n        stride (int or tuple, optional): Stride of the convolution. Default: 1\n        padding (int or tuple, optional): Zero-padding added to all three sides of the input. Default: 0\n    \"\"\"\n\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        kernel_size: Union[int, Tuple[int, int, int]],\n        stride: Union[int, Tuple[int, int, int]] = 1,\n        padding: Union[int, Tuple[int, int, int]] = 0,\n    ) -> None:\n        super().__init__(\n            in_channels=in_channels,\n            out_channels=out_channels,\n            kernel_size=kernel_size,\n            stride=stride,\n            padding=padding,\n        )\n\n        # Set up causal padding\n        self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0)\n        self.padding = (0, 0, 0)\n\n    def forward(self, x, cache_x=None):\n        padding = list(self._padding)\n        if cache_x is not None and self._padding[4] > 0:\n            cache_x = cache_x.to(x.device)\n            x = torch.cat([cache_x, x], dim=2)\n            padding[4] -= cache_x.shape[2]\n        x = torch.nn.functional.pad(x, padding)\n        return super().forward(x)\n\n\n\nclass QwenImageRMS_norm(nn.Module):\n    r\"\"\"\n    A custom RMS normalization layer.\n\n    Args:\n        dim (int): The number 
of dimensions to normalize over.\n        channel_first (bool, optional): Whether the input tensor has channels as the first dimension.\n            Default is True.\n        images (bool, optional): Whether the input represents image data. Default is True.\n        bias (bool, optional): Whether to include a learnable bias term. Default is False.\n    \"\"\"\n\n    def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None:\n        super().__init__()\n        broadcastable_dims = (1, 1, 1) if not images else (1, 1)\n        shape = (dim, *broadcastable_dims) if channel_first else (dim,)\n\n        self.channel_first = channel_first\n        self.scale = dim**0.5\n        self.gamma = nn.Parameter(torch.ones(shape))\n        self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0\n\n    def forward(self, x):\n        return torch.nn.functional.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias\n\n\n\nclass QwenImageResidualBlock(nn.Module):\n    r\"\"\"\n    A custom residual block module.\n\n    Args:\n        in_dim (int): Number of input channels.\n        out_dim (int): Number of output channels.\n        dropout (float, optional): Dropout rate for the dropout layer. Default is 0.0.\n        non_linearity (str, optional): Type of non-linearity to use. 
Default is \"silu\".\n    \"\"\"\n\n    def __init__(\n        self,\n        in_dim: int,\n        out_dim: int,\n        dropout: float = 0.0,\n        non_linearity: str = \"silu\",\n    ) -> None:\n        super().__init__()\n        self.in_dim = in_dim\n        self.out_dim = out_dim\n        self.nonlinearity = torch.nn.SiLU()\n\n        # layers\n        self.norm1 = QwenImageRMS_norm(in_dim, images=False)\n        self.conv1 = QwenImageCausalConv3d(in_dim, out_dim, 3, padding=1)\n        self.norm2 = QwenImageRMS_norm(out_dim, images=False)\n        self.dropout = nn.Dropout(dropout)\n        self.conv2 = QwenImageCausalConv3d(out_dim, out_dim, 3, padding=1)\n        self.conv_shortcut = QwenImageCausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()\n\n    def forward(self, x, feat_cache=None, feat_idx=[0]):\n        # Apply shortcut connection\n        h = self.conv_shortcut(x)\n\n        # First normalization and activation\n        x = self.norm1(x)\n        x = self.nonlinearity(x)\n\n        if feat_cache is not None:\n            idx = feat_idx[0]\n            cache_x = x[:, :, -CACHE_T:, :, :].clone()\n            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:\n                cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)\n\n            x = self.conv1(x, feat_cache[idx])\n            feat_cache[idx] = cache_x\n            feat_idx[0] += 1\n        else:\n            x = self.conv1(x)\n\n        # Second normalization and activation\n        x = self.norm2(x)\n        x = self.nonlinearity(x)\n\n        # Dropout\n        x = self.dropout(x)\n\n        if feat_cache is not None:\n            idx = feat_idx[0]\n            cache_x = x[:, :, -CACHE_T:, :, :].clone()\n            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:\n                cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)\n\n            
x = self.conv2(x, feat_cache[idx])\n            feat_cache[idx] = cache_x\n            feat_idx[0] += 1\n        else:\n            x = self.conv2(x)\n\n        # Add residual connection\n        return x + h\n\n\n\nclass QwenImageAttentionBlock(nn.Module):\n    r\"\"\"\n    Causal self-attention with a single head.\n\n    Args:\n        dim (int): The number of channels in the input tensor.\n    \"\"\"\n\n    def __init__(self, dim):\n        super().__init__()\n        self.dim = dim\n\n        # layers\n        self.norm = QwenImageRMS_norm(dim)\n        self.to_qkv = nn.Conv2d(dim, dim * 3, 1)\n        self.proj = nn.Conv2d(dim, dim, 1)\n\n    def forward(self, x):\n        identity = x\n        batch_size, channels, time, height, width = x.size()\n\n        x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * time, channels, height, width)\n        x = self.norm(x)\n\n        # compute query, key, value\n        qkv = self.to_qkv(x)\n        qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1)\n        qkv = qkv.permute(0, 1, 3, 2).contiguous()\n        q, k, v = qkv.chunk(3, dim=-1)\n\n        # apply attention\n        x = torch.nn.functional.scaled_dot_product_attention(q, k, v)\n\n        x = x.squeeze(1).permute(0, 2, 1).reshape(batch_size * time, channels, height, width)\n\n        # output projection\n        x = self.proj(x)\n\n        # Reshape back: [(b*t), c, h, w] -> [b, c, t, h, w]\n        x = x.view(batch_size, time, channels, height, width)\n        x = x.permute(0, 2, 1, 3, 4)\n\n        return x + identity\n\n\n\nclass QwenImageUpsample(nn.Upsample):\n    r\"\"\"\n    Perform upsampling while ensuring the output tensor has the same data type as the input.\n\n    Args:\n        x (torch.Tensor): Input tensor to be upsampled.\n\n    Returns:\n        torch.Tensor: Upsampled tensor with the same data type as the input.\n    \"\"\"\n\n    def forward(self, x):\n        return super().forward(x.float()).type_as(x)\n\n\n\nclass 
QwenImageResample(nn.Module):\n    r\"\"\"\n    A custom resampling module for 2D and 3D data.\n\n    Args:\n        dim (int): The number of input/output channels.\n        mode (str): The resampling mode. Must be one of:\n            - 'none': No resampling (identity operation).\n            - 'upsample2d': 2D upsampling with nearest-exact interpolation and convolution.\n            - 'upsample3d': 3D upsampling with nearest-exact interpolation, convolution, and causal 3D convolution.\n            - 'downsample2d': 2D downsampling with zero-padding and convolution.\n            - 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution.\n    \"\"\"\n\n    def __init__(self, dim: int, mode: str) -> None:\n        super().__init__()\n        self.dim = dim\n        self.mode = mode\n\n        # layers\n        if mode == \"upsample2d\":\n            self.resample = nn.Sequential(\n                QwenImageUpsample(scale_factor=(2.0, 2.0), mode=\"nearest-exact\"), nn.Conv2d(dim, dim // 2, 3, padding=1)\n            )\n        elif mode == \"upsample3d\":\n            self.resample = nn.Sequential(\n                QwenImageUpsample(scale_factor=(2.0, 2.0), mode=\"nearest-exact\"), nn.Conv2d(dim, dim // 2, 3, padding=1)\n            )\n            self.time_conv = QwenImageCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))\n\n        elif mode == \"downsample2d\":\n            self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))\n        elif mode == \"downsample3d\":\n            self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))\n            self.time_conv = QwenImageCausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))\n\n        else:\n            self.resample = nn.Identity()\n\n    def forward(self, x, feat_cache=None, feat_idx=[0]):\n        b, c, t, h, w = x.size()\n        if self.mode == \"upsample3d\":\n          
  if feat_cache is not None:\n                idx = feat_idx[0]\n                if feat_cache[idx] is None:\n                    feat_cache[idx] = \"Rep\"\n                    feat_idx[0] += 1\n                else:\n                    cache_x = x[:, :, -CACHE_T:, :, :].clone()\n                    if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != \"Rep\":\n                        # cache last frame of last two chunk\n                        cache_x = torch.cat(\n                            [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2\n                        )\n                    if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == \"Rep\":\n                        cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2)\n                    if feat_cache[idx] == \"Rep\":\n                        x = self.time_conv(x)\n                    else:\n                        x = self.time_conv(x, feat_cache[idx])\n                    feat_cache[idx] = cache_x\n                    feat_idx[0] += 1\n\n                    x = x.reshape(b, 2, c, t, h, w)\n                    x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)\n                    x = x.reshape(b, c, t * 2, h, w)\n        t = x.shape[2]\n        x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)\n        x = self.resample(x)\n        x = x.view(b, t, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4)\n\n        if self.mode == \"downsample3d\":\n            if feat_cache is not None:\n                idx = feat_idx[0]\n                if feat_cache[idx] is None:\n                    feat_cache[idx] = x.clone()\n                    feat_idx[0] += 1\n                else:\n                    cache_x = x[:, :, -1:, :, :].clone()\n                    x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))\n                    feat_cache[idx] = cache_x\n      
              feat_idx[0] += 1\n        return x\n\n\n\nclass QwenImageMidBlock(nn.Module):\n    \"\"\"\n    Middle block for WanVAE encoder and decoder.\n\n    Args:\n        dim (int): Number of input/output channels.\n        dropout (float): Dropout rate.\n        non_linearity (str): Type of non-linearity to use.\n    \"\"\"\n\n    def __init__(self, dim: int, dropout: float = 0.0, non_linearity: str = \"silu\", num_layers: int = 1):\n        super().__init__()\n        self.dim = dim\n\n        # Create the components\n        resnets = [QwenImageResidualBlock(dim, dim, dropout, non_linearity)]\n        attentions = []\n        for _ in range(num_layers):\n            attentions.append(QwenImageAttentionBlock(dim))\n            resnets.append(QwenImageResidualBlock(dim, dim, dropout, non_linearity))\n        self.attentions = nn.ModuleList(attentions)\n        self.resnets = nn.ModuleList(resnets)\n\n        self.gradient_checkpointing = False\n\n    def forward(self, x, feat_cache=None, feat_idx=[0]):\n        # First residual block\n        x = self.resnets[0](x, feat_cache, feat_idx)\n\n        # Process through attention and residual blocks\n        for attn, resnet in zip(self.attentions, self.resnets[1:]):\n            if attn is not None:\n                x = attn(x)\n\n            x = resnet(x, feat_cache, feat_idx)\n\n        return x\n\n\n\nclass QwenImageEncoder3d(nn.Module):\n    r\"\"\"\n    A 3D encoder module.\n\n    Args:\n        dim (int): The base number of channels in the first layer.\n        z_dim (int): The dimensionality of the latent space.\n        dim_mult (list of int): Multipliers for the number of channels in each block.\n        num_res_blocks (int): Number of residual blocks in each block.\n        attn_scales (list of float): Scales at which to apply attention mechanisms.\n        temperal_downsample (list of bool): Whether to downsample temporally in each block.\n        dropout (float): Dropout rate for the dropout layers.\n 
       non_linearity (str): Type of non-linearity to use.\n    \"\"\"\n\n    def __init__(\n        self,\n        dim=128,\n        z_dim=4,\n        dim_mult=[1, 2, 4, 4],\n        num_res_blocks=2,\n        attn_scales=[],\n        temperal_downsample=[True, True, False],\n        dropout=0.0,\n        non_linearity: str = \"silu\",\n        image_channels=3\n    ):\n        super().__init__()\n        self.dim = dim\n        self.z_dim = z_dim\n        self.dim_mult = dim_mult\n        self.num_res_blocks = num_res_blocks\n        self.attn_scales = attn_scales\n        self.temperal_downsample = temperal_downsample\n        self.nonlinearity = torch.nn.SiLU()\n\n        # dimensions\n        dims = [dim * u for u in [1] + dim_mult]\n        scale = 1.0\n\n        # init block\n        self.conv_in = QwenImageCausalConv3d(image_channels, dims[0], 3, padding=1)\n\n        # downsample blocks\n        self.down_blocks = torch.nn.ModuleList([])\n        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):\n            # residual (+attention) blocks\n            for _ in range(num_res_blocks):\n                self.down_blocks.append(QwenImageResidualBlock(in_dim, out_dim, dropout))\n                if scale in attn_scales:\n                    self.down_blocks.append(QwenImageAttentionBlock(out_dim))\n                in_dim = out_dim\n\n            # downsample block\n            if i != len(dim_mult) - 1:\n                mode = \"downsample3d\" if temperal_downsample[i] else \"downsample2d\"\n                self.down_blocks.append(QwenImageResample(out_dim, mode=mode))\n                scale /= 2.0\n\n        # middle blocks\n        self.mid_block = QwenImageMidBlock(out_dim, dropout, non_linearity, num_layers=1)\n\n        # output blocks\n        self.norm_out = QwenImageRMS_norm(out_dim, images=False)\n        self.conv_out = QwenImageCausalConv3d(out_dim, z_dim, 3, padding=1)\n\n        self.gradient_checkpointing = False\n\n    def 
forward(self, x, feat_cache=None, feat_idx=[0]):\n        if feat_cache is not None:\n            idx = feat_idx[0]\n            cache_x = x[:, :, -CACHE_T:, :, :].clone()\n            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:\n                # cache last frame of last two chunk\n                cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)\n            x = self.conv_in(x, feat_cache[idx])\n            feat_cache[idx] = cache_x\n            feat_idx[0] += 1\n        else:\n            x = self.conv_in(x)\n\n        ## downsamples\n        for layer in self.down_blocks:\n            if feat_cache is not None:\n                x = layer(x, feat_cache, feat_idx)\n            else:\n                x = layer(x)\n\n        ## middle\n        x = self.mid_block(x, feat_cache, feat_idx)\n\n        ## head\n        x = self.norm_out(x)\n        x = self.nonlinearity(x)\n        if feat_cache is not None:\n            idx = feat_idx[0]\n            cache_x = x[:, :, -CACHE_T:, :, :].clone()\n            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:\n                # cache last frame of last two chunk\n                cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)\n            x = self.conv_out(x, feat_cache[idx])\n            feat_cache[idx] = cache_x\n            feat_idx[0] += 1\n        else:\n            x = self.conv_out(x)\n        return x\n\n\n\nclass QwenImageUpBlock(nn.Module):\n    \"\"\"\n    A block that handles upsampling for the WanVAE decoder.\n\n    Args:\n        in_dim (int): Input dimension\n        out_dim (int): Output dimension\n        num_res_blocks (int): Number of residual blocks\n        dropout (float): Dropout rate\n        upsample_mode (str, optional): Mode for upsampling ('upsample2d' or 'upsample3d')\n        non_linearity (str): Type of non-linearity to use\n    \"\"\"\n\n    def __init__(\n        
self,\n        in_dim: int,\n        out_dim: int,\n        num_res_blocks: int,\n        dropout: float = 0.0,\n        upsample_mode: Optional[str] = None,\n        non_linearity: str = \"silu\",\n    ):\n        super().__init__()\n        self.in_dim = in_dim\n        self.out_dim = out_dim\n\n        # Create layers list\n        resnets = []\n        # Add residual blocks and attention if needed\n        current_dim = in_dim\n        for _ in range(num_res_blocks + 1):\n            resnets.append(QwenImageResidualBlock(current_dim, out_dim, dropout, non_linearity))\n            current_dim = out_dim\n\n        self.resnets = nn.ModuleList(resnets)\n\n        # Add upsampling layer if needed\n        self.upsamplers = None\n        if upsample_mode is not None:\n            self.upsamplers = nn.ModuleList([QwenImageResample(out_dim, mode=upsample_mode)])\n\n        self.gradient_checkpointing = False\n\n    def forward(self, x, feat_cache=None, feat_idx=[0]):\n        \"\"\"\n        Forward pass through the upsampling block.\n\n        Args:\n            x (torch.Tensor): Input tensor\n            feat_cache (list, optional): Feature cache for causal convolutions\n            feat_idx (list, optional): Feature index for cache management\n\n        Returns:\n            torch.Tensor: Output tensor\n        \"\"\"\n        for resnet in self.resnets:\n            if feat_cache is not None:\n                x = resnet(x, feat_cache, feat_idx)\n            else:\n                x = resnet(x)\n\n        if self.upsamplers is not None:\n            if feat_cache is not None:\n                x = self.upsamplers[0](x, feat_cache, feat_idx)\n            else:\n                x = self.upsamplers[0](x)\n        return x\n\n\n\nclass QwenImageDecoder3d(nn.Module):\n    r\"\"\"\n    A 3D decoder module.\n\n    Args:\n        dim (int): The base number of channels in the first layer.\n        z_dim (int): The dimensionality of the latent space.\n        dim_mult (list 
of int): Multipliers for the number of channels in each block.\n        num_res_blocks (int): Number of residual blocks in each block.\n        attn_scales (list of float): Scales at which to apply attention mechanisms.\n        temperal_upsample (list of bool): Whether to upsample temporally in each block.\n        dropout (float): Dropout rate for the dropout layers.\n        non_linearity (str): Type of non-linearity to use.\n    \"\"\"\n\n    def __init__(\n        self,\n        dim=128,\n        z_dim=4,\n        dim_mult=[1, 2, 4, 4],\n        num_res_blocks=2,\n        attn_scales=[],\n        temperal_upsample=[False, True, True],\n        dropout=0.0,\n        non_linearity: str = \"silu\",\n        image_channels=3,\n    ):\n        super().__init__()\n        self.dim = dim\n        self.z_dim = z_dim\n        self.dim_mult = dim_mult\n        self.num_res_blocks = num_res_blocks\n        self.attn_scales = attn_scales\n        self.temperal_upsample = temperal_upsample\n\n        self.nonlinearity = torch.nn.SiLU()\n\n        # dimensions\n        dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]\n        scale = 1.0 / 2 ** (len(dim_mult) - 2)\n\n        # init block\n        self.conv_in = QwenImageCausalConv3d(z_dim, dims[0], 3, padding=1)\n\n        # middle blocks\n        self.mid_block = QwenImageMidBlock(dims[0], dropout, non_linearity, num_layers=1)\n\n        # upsample blocks\n        self.up_blocks = nn.ModuleList([])\n        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):\n            # residual (+attention) blocks\n            if i > 0:\n                in_dim = in_dim // 2\n\n            # Determine if we need upsampling\n            upsample_mode = None\n            if i != len(dim_mult) - 1:\n                upsample_mode = \"upsample3d\" if temperal_upsample[i] else \"upsample2d\"\n\n            # Create and add the upsampling block\n            up_block = QwenImageUpBlock(\n                in_dim=in_dim,\n    
            out_dim=out_dim,\n                num_res_blocks=num_res_blocks,\n                dropout=dropout,\n                upsample_mode=upsample_mode,\n                non_linearity=non_linearity,\n            )\n            self.up_blocks.append(up_block)\n\n            # Update scale for next iteration\n            if upsample_mode is not None:\n                scale *= 2.0\n\n        # output blocks\n        self.norm_out = QwenImageRMS_norm(out_dim, images=False)\n        self.conv_out = QwenImageCausalConv3d(out_dim, image_channels, 3, padding=1)\n\n        self.gradient_checkpointing = False\n\n    def forward(self, x, feat_cache=None, feat_idx=[0]):\n        ## conv1\n        if feat_cache is not None:\n            idx = feat_idx[0]\n            cache_x = x[:, :, -CACHE_T:, :, :].clone()\n            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:\n                # cache last frame of last two chunk\n                cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)\n            x = self.conv_in(x, feat_cache[idx])\n            feat_cache[idx] = cache_x\n            feat_idx[0] += 1\n        else:\n            x = self.conv_in(x)\n\n        ## middle\n        x = self.mid_block(x, feat_cache, feat_idx)\n\n        ## upsamples\n        for up_block in self.up_blocks:\n            x = up_block(x, feat_cache, feat_idx)\n\n        ## head\n        x = self.norm_out(x)\n        x = self.nonlinearity(x)\n        if feat_cache is not None:\n            idx = feat_idx[0]\n            cache_x = x[:, :, -CACHE_T:, :, :].clone()\n            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:\n                # cache last frame of last two chunk\n                cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)\n            x = self.conv_out(x, feat_cache[idx])\n            feat_cache[idx] = cache_x\n            feat_idx[0] += 1\n        else:\n 
           x = self.conv_out(x)\n        return x\n\n\n\nclass QwenImageVAE(torch.nn.Module):\n    def __init__(\n        self,\n        base_dim: int = 96,\n        z_dim: int = 16,\n        dim_mult: Tuple[int] = [1, 2, 4, 4],\n        num_res_blocks: int = 2,\n        attn_scales: List[float] = [],\n        temperal_downsample: List[bool] = [False, True, True],\n        dropout: float = 0.0,\n        image_channels: int = 3,\n    ) -> None:\n        super().__init__()\n\n        self.z_dim = z_dim\n        self.temperal_downsample = temperal_downsample\n        self.temperal_upsample = temperal_downsample[::-1]\n\n        self.encoder = QwenImageEncoder3d(\n            base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout, image_channels=image_channels,\n        )\n        self.quant_conv = QwenImageCausalConv3d(z_dim * 2, z_dim * 2, 1)\n        self.post_quant_conv = QwenImageCausalConv3d(z_dim, z_dim, 1)\n\n        self.decoder = QwenImageDecoder3d(\n            base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout, image_channels=image_channels,\n        )\n\n        mean = [\n            -0.7571,\n            -0.7089,\n            -0.9113,\n            0.1075,\n            -0.1745,\n            0.9653,\n            -0.1517,\n            1.5508,\n            0.4134,\n            -0.0715,\n            0.5517,\n            -0.3632,\n            -0.1922,\n            -0.9497,\n            0.2503,\n            -0.2921,\n        ]\n        std = [\n            2.8184,\n            1.4541,\n            2.3275,\n            2.6558,\n            1.2196,\n            1.7708,\n            2.6052,\n            2.0743,\n            3.2687,\n            2.1526,\n            2.8652,\n            1.5579,\n            1.6382,\n            1.1253,\n            2.8251,\n            1.9160,\n        ]\n        self.mean = torch.tensor(mean).view(1, 16, 1, 1, 1)\n        self.std = 1 / 
torch.tensor(std).view(1, 16, 1, 1, 1)\n\n    def encode(self, x, **kwargs):\n        x = x.unsqueeze(2)\n        x = self.encoder(x)\n        x = self.quant_conv(x)\n        x = x[:, :16]\n        mean, std = self.mean.to(dtype=x.dtype, device=x.device), self.std.to(dtype=x.dtype, device=x.device)\n        x = (x - mean) * std\n        x = x.squeeze(2)\n        return x\n    \n    def decode(self, x, **kwargs):\n        x = x.unsqueeze(2)\n        mean, std = self.mean.to(dtype=x.dtype, device=x.device), self.std.to(dtype=x.dtype, device=x.device)\n        x = x / std + mean\n        x = self.post_quant_conv(x)\n        x = self.decoder(x)\n        x = x.squeeze(2)\n        return x\n"
  },
  {
    "path": "diffsynth/models/sd_text_encoder.py",
    "content": "import torch\nfrom .attention import Attention\nfrom einops import rearrange\n\n\ndef low_version_attention(query, key, value, attn_bias=None):\n    scale = 1 / query.shape[-1] ** 0.5\n    query = query * scale\n    attn = torch.matmul(query, key.transpose(-2, -1))\n    if attn_bias is not None:\n        attn = attn + attn_bias\n    attn = attn.softmax(-1)\n    return attn @ value\n\n\nclass Attention(torch.nn.Module):\n\n    def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):\n        super().__init__()\n        dim_inner = head_dim * num_heads\n        kv_dim = kv_dim if kv_dim is not None else q_dim\n        self.num_heads = num_heads\n        self.head_dim = head_dim\n\n        self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q)\n        self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)\n        self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)\n        self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out)\n\n    def interact_with_ipadapter(self, hidden_states, q, ip_k, ip_v, scale=1.0):\n        batch_size = q.shape[0]\n        ip_k = ip_k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)\n        ip_v = ip_v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)\n        ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v)\n        hidden_states = hidden_states + scale * ip_hidden_states\n        return hidden_states\n\n    def torch_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None):\n        if encoder_hidden_states is None:\n            encoder_hidden_states = hidden_states\n\n        batch_size = encoder_hidden_states.shape[0]\n\n        q = self.to_q(hidden_states)\n        k = self.to_k(encoder_hidden_states)\n        v = self.to_v(encoder_hidden_states)\n\n        q = q.view(batch_size, -1, self.num_heads, 
self.head_dim).transpose(1, 2)\n        k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)\n        v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)\n\n        if qkv_preprocessor is not None:\n            q, k, v = qkv_preprocessor(q, k, v)\n\n        hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)\n        if ipadapter_kwargs is not None:\n            hidden_states = self.interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs)\n        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)\n        hidden_states = hidden_states.to(q.dtype)\n\n        hidden_states = self.to_out(hidden_states)\n\n        return hidden_states\n    \n    def xformers_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):\n        if encoder_hidden_states is None:\n            encoder_hidden_states = hidden_states\n\n        q = self.to_q(hidden_states)\n        k = self.to_k(encoder_hidden_states)\n        v = self.to_v(encoder_hidden_states)\n\n        q = rearrange(q, \"b f (n d) -> (b n) f d\", n=self.num_heads)\n        k = rearrange(k, \"b f (n d) -> (b n) f d\", n=self.num_heads)\n        v = rearrange(v, \"b f (n d) -> (b n) f d\", n=self.num_heads)\n\n        if attn_mask is not None:\n            hidden_states = low_version_attention(q, k, v, attn_bias=attn_mask)\n        else:\n            import xformers.ops as xops\n            hidden_states = xops.memory_efficient_attention(q, k, v)\n        hidden_states = rearrange(hidden_states, \"(b n) f d -> b f (n d)\", n=self.num_heads)\n\n        hidden_states = hidden_states.to(q.dtype)\n        hidden_states = self.to_out(hidden_states)\n\n        return hidden_states\n\n    def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None):\n        return self.torch_forward(hidden_states, 
encoder_hidden_states=encoder_hidden_states, attn_mask=attn_mask, ipadapter_kwargs=ipadapter_kwargs, qkv_preprocessor=qkv_preprocessor)\n\n\n\n\n\nclass CLIPEncoderLayer(torch.nn.Module):\n    def __init__(self, embed_dim, intermediate_size, num_heads=12, head_dim=64, use_quick_gelu=True):\n        super().__init__()\n        self.attn = Attention(q_dim=embed_dim, num_heads=num_heads, head_dim=head_dim, bias_q=True, bias_kv=True, bias_out=True)\n        self.layer_norm1 = torch.nn.LayerNorm(embed_dim)\n        self.layer_norm2 = torch.nn.LayerNorm(embed_dim)\n        self.fc1 = torch.nn.Linear(embed_dim, intermediate_size)\n        self.fc2 = torch.nn.Linear(intermediate_size, embed_dim)\n\n        self.use_quick_gelu = use_quick_gelu\n\n    def quickGELU(self, x):\n        return x * torch.sigmoid(1.702 * x)\n    \n    def forward(self, hidden_states, attn_mask=None):\n        residual = hidden_states\n\n        hidden_states = self.layer_norm1(hidden_states)\n        hidden_states = self.attn(hidden_states, attn_mask=attn_mask)\n        hidden_states = residual + hidden_states\n\n        residual = hidden_states\n        hidden_states = self.layer_norm2(hidden_states)\n        hidden_states = self.fc1(hidden_states)\n        if self.use_quick_gelu:\n            hidden_states = self.quickGELU(hidden_states)\n        else:\n            hidden_states = torch.nn.functional.gelu(hidden_states)\n        hidden_states = self.fc2(hidden_states)\n        hidden_states = residual + hidden_states\n\n        return hidden_states\n    \n\nclass SDTextEncoder(torch.nn.Module):\n    def __init__(self, embed_dim=768, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=12, encoder_intermediate_size=3072):\n        super().__init__()\n\n        # token_embedding\n        self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim)\n\n        # position_embeds (This is a fixed tensor)\n        self.position_embeds = torch.nn.Parameter(torch.zeros(1, 
max_position_embeddings, embed_dim))\n\n        # encoders\n        self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size) for _ in range(num_encoder_layers)])\n\n        # attn_mask\n        self.attn_mask = self.attention_mask(max_position_embeddings)\n\n        # final_layer_norm\n        self.final_layer_norm = torch.nn.LayerNorm(embed_dim)\n\n    def attention_mask(self, length):\n        mask = torch.empty(length, length)\n        mask.fill_(float(\"-inf\"))\n        mask.triu_(1)\n        return mask\n\n    def forward(self, input_ids, clip_skip=1):\n        embeds = self.token_embedding(input_ids) + self.position_embeds\n        attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)\n        for encoder_id, encoder in enumerate(self.encoders):\n            embeds = encoder(embeds, attn_mask=attn_mask)\n            if encoder_id + clip_skip == len(self.encoders):\n                break\n        embeds = self.final_layer_norm(embeds)\n        return embeds\n    \n    @staticmethod\n    def state_dict_converter():\n        return SDTextEncoderStateDictConverter()\n\n\nclass SDTextEncoderStateDictConverter:\n    def __init__(self):\n        pass\n\n    def from_diffusers(self, state_dict):\n        rename_dict = {\n            \"text_model.embeddings.token_embedding.weight\": \"token_embedding.weight\",\n            \"text_model.embeddings.position_embedding.weight\": \"position_embeds\",\n            \"text_model.final_layer_norm.weight\": \"final_layer_norm.weight\",\n            \"text_model.final_layer_norm.bias\": \"final_layer_norm.bias\"\n        }\n        attn_rename_dict = {\n            \"self_attn.q_proj\": \"attn.to_q\",\n            \"self_attn.k_proj\": \"attn.to_k\",\n            \"self_attn.v_proj\": \"attn.to_v\",\n            \"self_attn.out_proj\": \"attn.to_out\",\n            \"layer_norm1\": \"layer_norm1\",\n            \"layer_norm2\": \"layer_norm2\",\n            \"mlp.fc1\": 
\"fc1\",\n            \"mlp.fc2\": \"fc2\",\n        }\n        state_dict_ = {}\n        for name in state_dict:\n            if name in rename_dict:\n                param = state_dict[name]\n                if name == \"text_model.embeddings.position_embedding.weight\":\n                    param = param.reshape((1, param.shape[0], param.shape[1]))\n                state_dict_[rename_dict[name]] = param\n            elif name.startswith(\"text_model.encoder.layers.\"):\n                param = state_dict[name]\n                names = name.split(\".\")\n                layer_id, layer_type, tail = names[3], \".\".join(names[4:-1]), names[-1]\n                name_ = \".\".join([\"encoders\", layer_id, attn_rename_dict[layer_type], tail])\n                state_dict_[name_] = param\n        return state_dict_\n    \n    def from_civitai(self, state_dict):\n        rename_dict = {\n            \"cond_stage_model.transformer.text_model.embeddings.token_embedding.weight\": \"token_embedding.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.bias\": \"encoders.0.layer_norm1.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.weight\": \"encoders.0.layer_norm1.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.bias\": \"encoders.0.layer_norm2.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.weight\": \"encoders.0.layer_norm2.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.bias\": \"encoders.0.fc1.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.weight\": \"encoders.0.fc1.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.bias\": \"encoders.0.fc2.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.weight\": \"encoders.0.fc2.weight\",\n            
\"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.bias\": \"encoders.0.attn.to_k.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.weight\": \"encoders.0.attn.to_k.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.bias\": \"encoders.0.attn.to_out.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.weight\": \"encoders.0.attn.to_out.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.bias\": \"encoders.0.attn.to_q.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight\": \"encoders.0.attn.to_q.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.bias\": \"encoders.0.attn.to_v.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.weight\": \"encoders.0.attn.to_v.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.bias\": \"encoders.1.layer_norm1.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.weight\": \"encoders.1.layer_norm1.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.bias\": \"encoders.1.layer_norm2.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.weight\": \"encoders.1.layer_norm2.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.bias\": \"encoders.1.fc1.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.weight\": \"encoders.1.fc1.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.bias\": \"encoders.1.fc2.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.weight\": 
\"encoders.1.fc2.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.bias\": \"encoders.1.attn.to_k.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.weight\": \"encoders.1.attn.to_k.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.bias\": \"encoders.1.attn.to_out.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.weight\": \"encoders.1.attn.to_out.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.bias\": \"encoders.1.attn.to_q.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.weight\": \"encoders.1.attn.to_q.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.bias\": \"encoders.1.attn.to_v.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.weight\": \"encoders.1.attn.to_v.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.bias\": \"encoders.10.layer_norm1.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.weight\": \"encoders.10.layer_norm1.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.bias\": \"encoders.10.layer_norm2.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.weight\": \"encoders.10.layer_norm2.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.bias\": \"encoders.10.fc1.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.weight\": \"encoders.10.fc1.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.bias\": \"encoders.10.fc2.bias\",\n            
\"cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.weight\": \"encoders.10.fc2.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.bias\": \"encoders.10.attn.to_k.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.weight\": \"encoders.10.attn.to_k.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.bias\": \"encoders.10.attn.to_out.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.weight\": \"encoders.10.attn.to_out.weight\",        \n            \"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.bias\": \"encoders.10.attn.to_q.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.weight\": \"encoders.10.attn.to_q.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.bias\": \"encoders.10.attn.to_v.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.weight\": \"encoders.10.attn.to_v.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.bias\": \"encoders.11.layer_norm1.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.weight\": \"encoders.11.layer_norm1.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.bias\": \"encoders.11.layer_norm2.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.weight\": \"encoders.11.layer_norm2.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.bias\": \"encoders.11.fc1.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.weight\": \"encoders.11.fc1.weight\",\n            
\"cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.bias\": \"encoders.11.fc2.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.weight\": \"encoders.11.fc2.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.bias\": \"encoders.11.attn.to_k.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.weight\": \"encoders.11.attn.to_k.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.bias\": \"encoders.11.attn.to_out.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.weight\": \"encoders.11.attn.to_out.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.bias\": \"encoders.11.attn.to_q.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.weight\": \"encoders.11.attn.to_q.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.bias\": \"encoders.11.attn.to_v.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.weight\": \"encoders.11.attn.to_v.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.bias\": \"encoders.2.layer_norm1.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.weight\": \"encoders.2.layer_norm1.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.bias\": \"encoders.2.layer_norm2.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.weight\": \"encoders.2.layer_norm2.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.bias\": \"encoders.2.fc1.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.weight\": 
\"encoders.2.fc1.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.bias\": \"encoders.2.fc2.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.weight\": \"encoders.2.fc2.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.bias\": \"encoders.2.attn.to_k.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.weight\": \"encoders.2.attn.to_k.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.bias\": \"encoders.2.attn.to_out.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.weight\": \"encoders.2.attn.to_out.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.bias\": \"encoders.2.attn.to_q.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.weight\": \"encoders.2.attn.to_q.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.bias\": \"encoders.2.attn.to_v.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.weight\": \"encoders.2.attn.to_v.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.bias\": \"encoders.3.layer_norm1.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.weight\": \"encoders.3.layer_norm1.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.bias\": \"encoders.3.layer_norm2.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.weight\": \"encoders.3.layer_norm2.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.bias\": \"encoders.3.fc1.bias\",\n            
\"cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.weight\": \"encoders.3.fc1.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.bias\": \"encoders.3.fc2.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.weight\": \"encoders.3.fc2.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.bias\": \"encoders.3.attn.to_k.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.weight\": \"encoders.3.attn.to_k.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.bias\": \"encoders.3.attn.to_out.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.weight\": \"encoders.3.attn.to_out.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.bias\": \"encoders.3.attn.to_q.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.weight\": \"encoders.3.attn.to_q.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.bias\": \"encoders.3.attn.to_v.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.weight\": \"encoders.3.attn.to_v.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.bias\": \"encoders.4.layer_norm1.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.weight\": \"encoders.4.layer_norm1.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.bias\": \"encoders.4.layer_norm2.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.weight\": \"encoders.4.layer_norm2.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.bias\": 
\"encoders.4.fc1.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.weight\": \"encoders.4.fc1.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.bias\": \"encoders.4.fc2.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.weight\": \"encoders.4.fc2.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.bias\": \"encoders.4.attn.to_k.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.weight\": \"encoders.4.attn.to_k.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.bias\": \"encoders.4.attn.to_out.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.weight\": \"encoders.4.attn.to_out.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.bias\": \"encoders.4.attn.to_q.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.weight\": \"encoders.4.attn.to_q.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.bias\": \"encoders.4.attn.to_v.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.weight\": \"encoders.4.attn.to_v.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.bias\": \"encoders.5.layer_norm1.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.weight\": \"encoders.5.layer_norm1.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.bias\": \"encoders.5.layer_norm2.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.weight\": \"encoders.5.layer_norm2.weight\",\n            
\"cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.bias\": \"encoders.5.fc1.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.weight\": \"encoders.5.fc1.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.bias\": \"encoders.5.fc2.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.weight\": \"encoders.5.fc2.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.bias\": \"encoders.5.attn.to_k.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.weight\": \"encoders.5.attn.to_k.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.bias\": \"encoders.5.attn.to_out.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.weight\": \"encoders.5.attn.to_out.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.bias\": \"encoders.5.attn.to_q.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.weight\": \"encoders.5.attn.to_q.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.bias\": \"encoders.5.attn.to_v.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.weight\": \"encoders.5.attn.to_v.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.bias\": \"encoders.6.layer_norm1.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.weight\": \"encoders.6.layer_norm1.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.bias\": \"encoders.6.layer_norm2.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.weight\": 
\"encoders.6.layer_norm2.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.bias\": \"encoders.6.fc1.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.weight\": \"encoders.6.fc1.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.bias\": \"encoders.6.fc2.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.weight\": \"encoders.6.fc2.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.bias\": \"encoders.6.attn.to_k.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.weight\": \"encoders.6.attn.to_k.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.bias\": \"encoders.6.attn.to_out.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.weight\": \"encoders.6.attn.to_out.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.bias\": \"encoders.6.attn.to_q.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.weight\": \"encoders.6.attn.to_q.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.bias\": \"encoders.6.attn.to_v.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.weight\": \"encoders.6.attn.to_v.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.bias\": \"encoders.7.layer_norm1.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.weight\": \"encoders.7.layer_norm1.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.bias\": \"encoders.7.layer_norm2.bias\",\n            
\"cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.weight\": \"encoders.7.layer_norm2.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.bias\": \"encoders.7.fc1.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.weight\": \"encoders.7.fc1.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.bias\": \"encoders.7.fc2.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.weight\": \"encoders.7.fc2.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.bias\": \"encoders.7.attn.to_k.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.weight\": \"encoders.7.attn.to_k.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.bias\": \"encoders.7.attn.to_out.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.weight\": \"encoders.7.attn.to_out.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.bias\": \"encoders.7.attn.to_q.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.weight\": \"encoders.7.attn.to_q.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.bias\": \"encoders.7.attn.to_v.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.weight\": \"encoders.7.attn.to_v.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.bias\": \"encoders.8.layer_norm1.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.weight\": \"encoders.8.layer_norm1.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.bias\": 
\"encoders.8.layer_norm2.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.weight\": \"encoders.8.layer_norm2.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.bias\": \"encoders.8.fc1.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.weight\": \"encoders.8.fc1.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.bias\": \"encoders.8.fc2.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.weight\": \"encoders.8.fc2.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.bias\": \"encoders.8.attn.to_k.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.weight\": \"encoders.8.attn.to_k.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.bias\": \"encoders.8.attn.to_out.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.weight\": \"encoders.8.attn.to_out.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.bias\": \"encoders.8.attn.to_q.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.weight\": \"encoders.8.attn.to_q.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.bias\": \"encoders.8.attn.to_v.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.weight\": \"encoders.8.attn.to_v.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.bias\": \"encoders.9.layer_norm1.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.weight\": \"encoders.9.layer_norm1.weight\",\n            
\"cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.bias\": \"encoders.9.layer_norm2.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.weight\": \"encoders.9.layer_norm2.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.bias\": \"encoders.9.fc1.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.weight\": \"encoders.9.fc1.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.bias\": \"encoders.9.fc2.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.weight\": \"encoders.9.fc2.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.bias\": \"encoders.9.attn.to_k.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.weight\": \"encoders.9.attn.to_k.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.bias\": \"encoders.9.attn.to_out.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.weight\": \"encoders.9.attn.to_out.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.bias\": \"encoders.9.attn.to_q.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.weight\": \"encoders.9.attn.to_q.weight\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.bias\": \"encoders.9.attn.to_v.bias\",\n            \"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.weight\": \"encoders.9.attn.to_v.weight\",\n            \"cond_stage_model.transformer.text_model.final_layer_norm.bias\": \"final_layer_norm.bias\",\n            \"cond_stage_model.transformer.text_model.final_layer_norm.weight\": \"final_layer_norm.weight\",\n            
\"cond_stage_model.transformer.text_model.embeddings.position_embedding.weight\": \"position_embeds\"\n        }\n        state_dict_ = {}\n        for name in state_dict:\n            if name in rename_dict:\n                param = state_dict[name]\n                if name == \"cond_stage_model.transformer.text_model.embeddings.position_embedding.weight\":\n                    param = param.reshape((1, param.shape[0], param.shape[1]))\n                state_dict_[rename_dict[name]] = param\n        return state_dict_\n"
  },
  {
    "path": "diffsynth/models/siglip2_image_encoder.py",
    "content": "from transformers.models.siglip.modeling_siglip import SiglipVisionTransformer, SiglipVisionConfig\nfrom transformers import SiglipImageProcessor, Siglip2VisionModel, Siglip2VisionConfig, Siglip2ImageProcessorFast\nimport torch\n\nfrom diffsynth.core.device.npu_compatible_device import get_device_type\n\n\nclass Siglip2ImageEncoder(SiglipVisionTransformer):\n    def __init__(self):\n        config = SiglipVisionConfig(\n            attention_dropout = 0.0,\n            dtype = \"float32\",\n            hidden_act = \"gelu_pytorch_tanh\",\n            hidden_size = 1536,\n            image_size = 384,\n            intermediate_size = 6144,\n            layer_norm_eps = 1e-06,\n            model_type = \"siglip_vision_model\",\n            num_attention_heads = 16,\n            num_channels = 3,\n            num_hidden_layers = 40,\n            patch_size = 16,\n            transformers_version = \"4.56.1\",\n            _attn_implementation = \"sdpa\"\n        )\n        super().__init__(config)\n        self.processor = SiglipImageProcessor(\n            do_convert_rgb = None,\n            do_normalize = True,\n            do_rescale = True,\n            do_resize = True,\n            image_mean = [\n                0.5,\n                0.5,\n                0.5\n            ],\n            image_processor_type = \"SiglipImageProcessor\",\n            image_std = [\n                0.5,\n                0.5,\n                0.5\n            ],\n            processor_class = \"SiglipProcessor\",\n            resample = 2,\n            rescale_factor = 0.00392156862745098,\n            size = {\n                \"height\": 384,\n                \"width\": 384\n            }\n        )\n        \n    def forward(self, image, torch_dtype=torch.bfloat16, device=get_device_type()):\n        pixel_values = self.processor(images=[image], return_tensors=\"pt\")[\"pixel_values\"]\n        pixel_values = pixel_values.to(device=device, dtype=torch_dtype)\n    
    output_attentions = False\n        output_hidden_states = False\n        interpolate_pos_encoding = False\n\n        hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)\n\n        encoder_outputs = self.encoder(\n            inputs_embeds=hidden_states,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n        )\n\n        last_hidden_state = encoder_outputs.last_hidden_state\n        last_hidden_state = self.post_layernorm(last_hidden_state)\n\n        pooler_output = self.head(last_hidden_state) if self.use_head else None\n\n        return pooler_output\n\n\nclass Siglip2ImageEncoder428M(Siglip2VisionModel):\n    def __init__(self):\n        config = Siglip2VisionConfig(\n            attention_dropout = 0.0,\n            dtype = \"bfloat16\",\n            hidden_act = \"gelu_pytorch_tanh\",\n            hidden_size = 1152,\n            intermediate_size = 4304,\n            layer_norm_eps = 1e-06,\n            model_type = \"siglip2_vision_model\",\n            num_attention_heads = 16,\n            num_channels = 3,\n            num_hidden_layers = 27,\n            num_patches = 256,\n            patch_size = 16,\n            transformers_version = \"4.57.1\"\n        )\n        super().__init__(config)\n        self.processor = Siglip2ImageProcessorFast(\n            **{\n                \"data_format\": \"channels_first\",\n                \"default_to_square\": True,\n                \"device\": None,\n                \"disable_grouping\": None,\n                \"do_convert_rgb\": None,\n                \"do_normalize\": True,\n                \"do_pad\": None,\n                \"do_rescale\": True,\n                \"do_resize\": True,\n                \"image_mean\": [\n                    0.5,\n                    0.5,\n                    0.5\n                ],\n                \"image_processor_type\": \"Siglip2ImageProcessorFast\",\n               
 \"image_std\": [\n                    0.5,\n                    0.5,\n                    0.5\n                ],\n                \"input_data_format\": None,\n                \"max_num_patches\": 256,\n                \"pad_size\": None,\n                \"patch_size\": 16,\n                \"processor_class\": \"Siglip2Processor\",\n                \"resample\": 2,\n                \"rescale_factor\": 0.00392156862745098,\n                \"return_tensors\": None,\n            }\n        )\n        \n    def forward(self, image, torch_dtype=torch.bfloat16, device=\"cuda\"):\n        siglip_inputs = self.processor(images=[image], return_tensors=\"pt\").to(device)\n        shape = siglip_inputs.spatial_shapes[0]\n        hidden_state = super().forward(**siglip_inputs).last_hidden_state\n        B, N, C = hidden_state.shape\n        hidden_state = hidden_state[:, : shape[0] * shape[1]]\n        hidden_state = hidden_state.view(shape[0], shape[1], C)\n        hidden_state = hidden_state.to(torch_dtype)\n        return hidden_state\n"
  },
  {
    "path": "diffsynth/models/step1x_connector.py",
    "content": "from typing import Optional\n\nimport torch, math\nimport torch.nn\nfrom einops import rearrange\nfrom torch import nn\nfrom functools import partial\nfrom einops import rearrange\n\n\n\ndef attention(q, k, v, attn_mask, mode=\"torch\"):\n    q = q.transpose(1, 2)\n    k = k.transpose(1, 2)\n    v = v.transpose(1, 2)\n    x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)\n    x = rearrange(x, \"b n s d -> b s (n d)\")\n    return x\n    \n\n\nclass MLP(nn.Module):\n    \"\"\"MLP as used in Vision Transformer, MLP-Mixer and related networks\"\"\"\n\n    def __init__(\n        self,\n        in_channels,\n        hidden_channels=None,\n        out_features=None,\n        act_layer=nn.GELU,\n        norm_layer=None,\n        bias=True,\n        drop=0.0,\n        use_conv=False,\n        device=None,\n        dtype=None,\n    ):\n        super().__init__()\n        out_features = out_features or in_channels\n        hidden_channels = hidden_channels or in_channels\n        bias = (bias, bias)\n        drop_probs = (drop, drop)\n        linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear\n\n        self.fc1 = linear_layer(\n            in_channels, hidden_channels, bias=bias[0], device=device, dtype=dtype\n        )\n        self.act = act_layer()\n        self.drop1 = nn.Dropout(drop_probs[0])\n        self.norm = (\n            norm_layer(hidden_channels, device=device, dtype=dtype)\n            if norm_layer is not None\n            else nn.Identity()\n        )\n        self.fc2 = linear_layer(\n            hidden_channels, out_features, bias=bias[1], device=device, dtype=dtype\n        )\n        self.drop2 = nn.Dropout(drop_probs[1])\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop1(x)\n        x = self.norm(x)\n        x = self.fc2(x)\n        x = self.drop2(x)\n        return x\n    \n    \nclass TextProjection(nn.Module):\n    \"\"\"\n  
  Projects text embeddings. Also handles dropout for classifier-free guidance.\n\n    Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py\n    \"\"\"\n\n    def __init__(self, in_channels, hidden_size, act_layer, dtype=None, device=None):\n        factory_kwargs = {\"dtype\": dtype, \"device\": device}\n        super().__init__()\n        self.linear_1 = nn.Linear(\n            in_features=in_channels,\n            out_features=hidden_size,\n            bias=True,\n            **factory_kwargs,\n        )\n        self.act_1 = act_layer()\n        self.linear_2 = nn.Linear(\n            in_features=hidden_size,\n            out_features=hidden_size,\n            bias=True,\n            **factory_kwargs,\n        )\n\n    def forward(self, caption):\n        hidden_states = self.linear_1(caption)\n        hidden_states = self.act_1(hidden_states)\n        hidden_states = self.linear_2(hidden_states)\n        return hidden_states\n    \n    \nclass TimestepEmbedder(nn.Module):\n    \"\"\"\n    Embeds scalar timesteps into vector representations.\n    \"\"\"\n\n    def __init__(\n        self,\n        hidden_size,\n        act_layer,\n        frequency_embedding_size=256,\n        max_period=10000,\n        out_size=None,\n        dtype=None,\n        device=None,\n    ):\n        factory_kwargs = {\"dtype\": dtype, \"device\": device}\n        super().__init__()\n        self.frequency_embedding_size = frequency_embedding_size\n        self.max_period = max_period\n        if out_size is None:\n            out_size = hidden_size\n\n        self.mlp = nn.Sequential(\n            nn.Linear(\n                frequency_embedding_size, hidden_size, bias=True, **factory_kwargs\n            ),\n            act_layer(),\n            nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs),\n        )\n        nn.init.normal_(self.mlp[0].weight, std=0.02)  # type: ignore\n        nn.init.normal_(self.mlp[2].weight, 
std=0.02)  # type: ignore\n\n    @staticmethod\n    def timestep_embedding(t, dim, max_period=10000):\n        \"\"\"\n        Create sinusoidal timestep embeddings.\n\n        Args:\n            t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional.\n            dim (int): the dimension of the output.\n            max_period (int): controls the minimum frequency of the embeddings.\n\n        Returns:\n            embedding (torch.Tensor): An (N, D) Tensor of positional embeddings.\n\n        .. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py\n        \"\"\"\n        half = dim // 2\n        freqs = torch.exp(\n            -math.log(max_period)\n            * torch.arange(start=0, end=half, dtype=torch.float32)\n            / half\n        ).to(device=t.device)\n        args = t[:, None].float() * freqs[None]\n        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)\n        if dim % 2:\n            embedding = torch.cat(\n                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1\n            )\n        return embedding\n\n    def forward(self, t):\n        t_freq = self.timestep_embedding(\n            t, self.frequency_embedding_size, self.max_period\n        ).type(t.dtype)  # type: ignore\n        t_emb = self.mlp(t_freq)\n        return t_emb\n    \n    \ndef apply_gate(x, gate=None, tanh=False):\n    \"\"\"AI is creating summary for apply_gate\n\n    Args:\n        x (torch.Tensor): input tensor.\n        gate (torch.Tensor, optional): gate tensor. Defaults to None.\n        tanh (bool, optional): whether to use tanh function. 
Defaults to False.\n\n    Returns:\n        torch.Tensor: the output tensor after apply gate.\n    \"\"\"\n    if gate is None:\n        return x\n    if tanh:\n        return x * gate.unsqueeze(1).tanh()\n    else:\n        return x * gate.unsqueeze(1)\n\n\nclass RMSNorm(nn.Module):\n    def __init__(\n        self,\n        dim: int,\n        elementwise_affine=True,\n        eps: float = 1e-6,\n        device=None,\n        dtype=None,\n    ):\n        \"\"\"\n        Initialize the RMSNorm normalization layer.\n\n        Args:\n            dim (int): The dimension of the input tensor.\n            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.\n\n        Attributes:\n            eps (float): A small value added to the denominator for numerical stability.\n            weight (nn.Parameter): Learnable scaling parameter.\n\n        \"\"\"\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n        super().__init__()\n        self.eps = eps\n        if elementwise_affine:\n            self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))\n\n    def _norm(self, x):\n        \"\"\"\n        Apply the RMSNorm normalization to the input tensor.\n\n        Args:\n            x (torch.Tensor): The input tensor.\n\n        Returns:\n            torch.Tensor: The normalized tensor.\n\n        \"\"\"\n        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)\n\n    def forward(self, x):\n        \"\"\"\n        Forward pass through the RMSNorm layer.\n\n        Args:\n            x (torch.Tensor): The input tensor.\n\n        Returns:\n            torch.Tensor: The output tensor after applying RMSNorm.\n\n        \"\"\"\n        output = self._norm(x.float()).type_as(x)\n        if hasattr(self, \"weight\"):\n            output = output * self.weight\n        return output\n\n\ndef get_norm_layer(norm_layer):\n    \"\"\"\n    Get the normalization layer.\n\n    Args:\n        
norm_layer (str): The type of normalization layer.\n\n    Returns:\n        norm_layer (nn.Module): The normalization layer.\n    \"\"\"\n    if norm_layer == \"layer\":\n        return nn.LayerNorm\n    elif norm_layer == \"rms\":\n        return RMSNorm\n    else:\n        raise NotImplementedError(f\"Norm layer {norm_layer} is not implemented\")\n\n\ndef get_activation_layer(act_type):\n    \"\"\"get activation layer\n\n    Args:\n        act_type (str): the activation type\n\n    Returns:\n        torch.nn.functional: the activation layer\n    \"\"\"\n    if act_type == \"gelu\":\n        return lambda: nn.GELU()\n    elif act_type == \"gelu_tanh\":\n        return lambda: nn.GELU(approximate=\"tanh\")\n    elif act_type == \"relu\":\n        return nn.ReLU\n    elif act_type == \"silu\":\n        return nn.SiLU\n    else:\n        raise ValueError(f\"Unknown activation type: {act_type}\")\n\nclass IndividualTokenRefinerBlock(torch.nn.Module):\n    def __init__(\n        self,\n        hidden_size,\n        heads_num,\n        mlp_width_ratio: str = 4.0,\n        mlp_drop_rate: float = 0.0,\n        act_type: str = \"silu\",\n        qk_norm: bool = False,\n        qk_norm_type: str = \"layer\",\n        qkv_bias: bool = True,\n        need_CA: bool = False,\n        dtype: Optional[torch.dtype] = None,\n        device: Optional[torch.device] = None,\n    ):\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n        super().__init__()\n        self.need_CA = need_CA\n        self.heads_num = heads_num\n        head_dim = hidden_size // heads_num\n        mlp_hidden_dim = int(hidden_size * mlp_width_ratio)\n\n        self.norm1 = nn.LayerNorm(\n            hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs\n        )\n        self.self_attn_qkv = nn.Linear(\n            hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs\n        )\n        qk_norm_layer = get_norm_layer(qk_norm_type)\n        self.self_attn_q_norm = (\n    
        qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)\n            if qk_norm\n            else nn.Identity()\n        )\n        self.self_attn_k_norm = (\n            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)\n            if qk_norm\n            else nn.Identity()\n        )\n        self.self_attn_proj = nn.Linear(\n            hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs\n        )\n\n        self.norm2 = nn.LayerNorm(\n            hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs\n        )\n        act_layer = get_activation_layer(act_type)\n        self.mlp = MLP(\n            in_channels=hidden_size,\n            hidden_channels=mlp_hidden_dim,\n            act_layer=act_layer,\n            drop=mlp_drop_rate,\n            **factory_kwargs,\n        )\n\n        self.adaLN_modulation = nn.Sequential(\n            act_layer(),\n            nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs),\n        )\n\n        if self.need_CA:\n            self.cross_attnblock=CrossAttnBlock(hidden_size=hidden_size,\n                        heads_num=heads_num,\n                        mlp_width_ratio=mlp_width_ratio,\n                        mlp_drop_rate=mlp_drop_rate,\n                        act_type=act_type,\n                        qk_norm=qk_norm,\n                        qk_norm_type=qk_norm_type,\n                        qkv_bias=qkv_bias,\n                        **factory_kwargs,)\n        # Zero-initialize the modulation\n        nn.init.zeros_(self.adaLN_modulation[1].weight)\n        nn.init.zeros_(self.adaLN_modulation[1].bias)\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        c: torch.Tensor,  # timestep_aware_representations + context_aware_representations\n        attn_mask: torch.Tensor = None,\n        y: torch.Tensor = None,\n    ):\n        gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)\n\n        norm_x = 
self.norm1(x)\n        qkv = self.self_attn_qkv(norm_x)\n        q, k, v = rearrange(qkv, \"B L (K H D) -> K B L H D\", K=3, H=self.heads_num)\n        # Apply QK-Norm if needed\n        q = self.self_attn_q_norm(q).to(v)\n        k = self.self_attn_k_norm(k).to(v)\n\n        # Self-Attention\n        attn = attention(q, k, v, mode=\"torch\", attn_mask=attn_mask)\n\n        x = x + apply_gate(self.self_attn_proj(attn), gate_msa)\n        \n        if self.need_CA:\n            x = self.cross_attnblock(x, c, attn_mask, y)\n\n        # FFN Layer\n        x = x + apply_gate(self.mlp(self.norm2(x)), gate_mlp)\n\n        return x\n\n\n\n\nclass CrossAttnBlock(torch.nn.Module):\n    def __init__(\n        self,\n        hidden_size,\n        heads_num,\n        mlp_width_ratio: str = 4.0,\n        mlp_drop_rate: float = 0.0,\n        act_type: str = \"silu\",\n        qk_norm: bool = False,\n        qk_norm_type: str = \"layer\",\n        qkv_bias: bool = True,\n        dtype: Optional[torch.dtype] = None,\n        device: Optional[torch.device] = None,\n    ):\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n        super().__init__()\n        self.heads_num = heads_num\n        head_dim = hidden_size // heads_num\n\n        self.norm1 = nn.LayerNorm(\n            hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs\n        )\n        self.norm1_2 = nn.LayerNorm(\n            hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs\n        )\n        self.self_attn_q = nn.Linear(\n            hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs\n        )\n        self.self_attn_kv = nn.Linear(\n            hidden_size, hidden_size*2, bias=qkv_bias, **factory_kwargs\n        )\n        qk_norm_layer = get_norm_layer(qk_norm_type)\n        self.self_attn_q_norm = (\n            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)\n            if qk_norm\n            else nn.Identity()\n        )\n        
self.self_attn_k_norm = (\n            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)\n            if qk_norm\n            else nn.Identity()\n        )\n        self.self_attn_proj = nn.Linear(\n            hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs\n        )\n\n        self.norm2 = nn.LayerNorm(\n            hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs\n        )\n        act_layer = get_activation_layer(act_type)\n\n        self.adaLN_modulation = nn.Sequential(\n            act_layer(),\n            nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs),\n        )\n        # Zero-initialize the modulation\n        nn.init.zeros_(self.adaLN_modulation[1].weight)\n        nn.init.zeros_(self.adaLN_modulation[1].bias)\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        c: torch.Tensor,  # timestep_aware_representations + context_aware_representations\n        attn_mask: torch.Tensor = None,\n        y: torch.Tensor=None,\n        \n    ):\n        gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)\n\n        norm_x = self.norm1(x)\n        norm_y = self.norm1_2(y)\n        q = self.self_attn_q(norm_x)\n        q = rearrange(q, \"B L (H D) -> B L H D\",  H=self.heads_num)\n        kv = self.self_attn_kv(norm_y)\n        k, v = rearrange(kv, \"B L (K H D) -> K B L H D\", K=2, H=self.heads_num)\n        # Apply QK-Norm if needed\n        q = self.self_attn_q_norm(q).to(v)\n        k = self.self_attn_k_norm(k).to(v)\n\n        # Self-Attention\n        attn = attention(q, k, v, mode=\"torch\", attn_mask=attn_mask)\n\n        x = x + apply_gate(self.self_attn_proj(attn), gate_msa)\n\n        return x\n\n\n\nclass IndividualTokenRefiner(torch.nn.Module):\n    def __init__(\n        self,\n        hidden_size,\n        heads_num,\n        depth,\n        mlp_width_ratio: float = 4.0,\n        mlp_drop_rate: float = 0.0,\n        act_type: str = \"silu\",\n        
qk_norm: bool = False,\n        qk_norm_type: str = \"layer\",\n        qkv_bias: bool = True,\n        need_CA:bool=False,\n        dtype: Optional[torch.dtype] = None,\n        device: Optional[torch.device] = None,\n    ):  \n        \n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n        super().__init__()\n        self.need_CA = need_CA\n        self.blocks = nn.ModuleList(\n            [\n                IndividualTokenRefinerBlock(\n                    hidden_size=hidden_size,\n                    heads_num=heads_num,\n                    mlp_width_ratio=mlp_width_ratio,\n                    mlp_drop_rate=mlp_drop_rate,\n                    act_type=act_type,\n                    qk_norm=qk_norm,\n                    qk_norm_type=qk_norm_type,\n                    qkv_bias=qkv_bias,\n                    need_CA=self.need_CA,\n                    **factory_kwargs,\n                )\n                for _ in range(depth)\n            ]\n        )\n\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        c: torch.LongTensor,\n        mask: Optional[torch.Tensor] = None,\n        y:torch.Tensor=None,\n    ):\n        self_attn_mask = None\n        if mask is not None:\n            batch_size = mask.shape[0]\n            seq_len = mask.shape[1]\n            mask = mask.to(x.device)\n            # batch_size x 1 x seq_len x seq_len\n            self_attn_mask_1 = mask.view(batch_size, 1, 1, seq_len).repeat(\n                1, 1, seq_len, 1\n            )\n            # batch_size x 1 x seq_len x seq_len\n            self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)\n            # batch_size x 1 x seq_len x seq_len, 1 for broadcasting of heads_num\n            self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()\n            # avoids self-attention weight being NaN for padding tokens\n            self_attn_mask[:, :, :, 0] = True\n        \n        \n        for block in self.blocks:\n            x = block(x, c, 
self_attn_mask,y)\n\n        return x\n\n\nclass SingleTokenRefiner(torch.nn.Module):\n    \"\"\"\n    A single token refiner block for llm text embedding refine.\n    \"\"\"\n    def __init__(\n        self,\n        in_channels,\n        hidden_size,\n        heads_num,\n        depth,\n        mlp_width_ratio: float = 4.0,\n        mlp_drop_rate: float = 0.0,\n        act_type: str = \"silu\",\n        qk_norm: bool = False,\n        qk_norm_type: str = \"layer\",\n        qkv_bias: bool = True,\n        need_CA:bool=False,\n        attn_mode: str = \"torch\",\n        dtype: Optional[torch.dtype] = None,\n        device: Optional[torch.device] = None,\n    ):\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n        super().__init__()\n        self.attn_mode = attn_mode\n        self.need_CA = need_CA\n        assert self.attn_mode == \"torch\", \"Only support 'torch' mode for token refiner.\"\n\n        self.input_embedder = nn.Linear(\n            in_channels, hidden_size, bias=True, **factory_kwargs\n        )\n        if self.need_CA:\n            self.input_embedder_CA = nn.Linear(\n            in_channels, hidden_size, bias=True, **factory_kwargs\n        )\n\n        act_layer = get_activation_layer(act_type)\n        # Build timestep embedding layer\n        self.t_embedder = TimestepEmbedder(hidden_size, act_layer, **factory_kwargs)\n        # Build context embedding layer\n        self.c_embedder = TextProjection(\n            in_channels, hidden_size, act_layer, **factory_kwargs\n        )\n\n        self.individual_token_refiner = IndividualTokenRefiner(\n            hidden_size=hidden_size,\n            heads_num=heads_num,\n            depth=depth,\n            mlp_width_ratio=mlp_width_ratio,\n            mlp_drop_rate=mlp_drop_rate,\n            act_type=act_type,\n            qk_norm=qk_norm,\n            qk_norm_type=qk_norm_type,\n            qkv_bias=qkv_bias,\n            need_CA=need_CA,\n            **factory_kwargs,\n     
   )\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        t: torch.LongTensor,\n        mask: Optional[torch.LongTensor] = None,\n        y: torch.LongTensor=None,\n    ):\n        timestep_aware_representations = self.t_embedder(t)\n\n        if mask is None:\n            context_aware_representations = x.mean(dim=1)\n        else:\n            mask_float = mask.unsqueeze(-1)  # [b, s1, 1]\n            context_aware_representations = (x * mask_float).sum(\n                dim=1\n            ) / mask_float.sum(dim=1)\n        context_aware_representations = self.c_embedder(context_aware_representations)\n        c = timestep_aware_representations + context_aware_representations\n\n        x = self.input_embedder(x)\n        if self.need_CA:\n            y = self.input_embedder_CA(y)\n            x = self.individual_token_refiner(x, c, mask, y)\n        else:\n            x = self.individual_token_refiner(x, c, mask)\n\n        return x\n\n\nclass Qwen2Connector(torch.nn.Module):\n    def __init__(\n        self,\n        # biclip_dim=1024,\n        in_channels=3584,\n        hidden_size=4096,\n        heads_num=32,\n        depth=2,\n        need_CA=False,\n        device=None,\n        dtype=torch.bfloat16,\n    ):\n        super().__init__()\n        factory_kwargs = {\"device\": device, \"dtype\":dtype}\n\n        self.S =SingleTokenRefiner(in_channels=in_channels,hidden_size=hidden_size,heads_num=heads_num,depth=depth,need_CA=need_CA,**factory_kwargs)\n        self.global_proj_out=nn.Linear(in_channels,768)\n\n        self.scale_factor = nn.Parameter(torch.zeros(1))\n        with torch.no_grad():\n            self.scale_factor.data += -(1 - 0.09)\n\n    def forward(self, x,t,mask):\n        mask_float = mask.unsqueeze(-1)  # [b, s1, 1]\n        x_mean = (x * mask_float).sum(\n                dim=1\n            ) / mask_float.sum(dim=1) * (1 + self.scale_factor.to(dtype=x.dtype, device=x.device))\n\n        
global_out=self.global_proj_out(x_mean)\n        encoder_hidden_states = self.S(x,t,mask)\n        return encoder_hidden_states,global_out\n"
  },
  {
    "path": "diffsynth/models/step1x_text_encoder.py",
    "content": "import torch\nfrom typing import Optional, Union\nfrom .qwen_image_text_encoder import QwenImageTextEncoder\nfrom ..core.device.npu_compatible_device import get_device_type, get_torch_device\n\n\nclass Step1xEditEmbedder(torch.nn.Module):\n    def __init__(self, model: QwenImageTextEncoder, processor, max_length=640, dtype=torch.bfloat16, device=get_device_type()):\n        super().__init__()\n        self.max_length = max_length\n        self.dtype = dtype\n        self.device = device\n        \n        Qwen25VL_7b_PREFIX = '''Given a user prompt, generate an \"Enhanced prompt\" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:\n- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.\n- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.\\n\nHere are examples of how to transform or refine prompts:\n- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.\n- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.\\n\nPlease generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:\nUser Prompt:'''\n\n        self.prefix = Qwen25VL_7b_PREFIX\n        self.model = model\n        self.processor = processor\n        \n    def model_forward(\n        self,\n        model: QwenImageTextEncoder,\n        input_ids: Optional[torch.LongTensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        
past_key_values = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        pixel_values: Optional[torch.Tensor] = None,\n        pixel_values_videos: Optional[torch.FloatTensor] = None,\n        image_grid_thw: Optional[torch.LongTensor] = None,\n        video_grid_thw: Optional[torch.LongTensor] = None,\n        rope_deltas: Optional[torch.LongTensor] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n        second_per_grid_ts: Optional[torch.Tensor] = None,\n        logits_to_keep: Union[int, torch.Tensor] = 0,\n        **kwargs,\n    ):\n        output_attentions = output_attentions if output_attentions is not None else model.config.output_attentions\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else model.config.output_hidden_states\n        )\n\n        outputs = model.model(\n            input_ids=input_ids,\n            pixel_values=pixel_values,\n            pixel_values_videos=pixel_values_videos,\n            image_grid_thw=image_grid_thw,\n            video_grid_thw=video_grid_thw,\n            second_per_grid_ts=second_per_grid_ts,\n            position_ids=position_ids,\n            attention_mask=attention_mask,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=True,\n            cache_position=cache_position,\n            **kwargs,\n        )\n        return outputs.hidden_states\n        \n    def forward(self, caption, ref_images):\n        text_list = caption\n        embs = torch.zeros(\n            len(text_list),\n            self.max_length,\n            
self.model.config.hidden_size,\n            dtype=torch.bfloat16,\n            device=get_torch_device().current_device(),\n        )\n        masks = torch.zeros(\n            len(text_list),\n            self.max_length,\n            dtype=torch.long,\n            device=get_torch_device().current_device(),\n        )\n\n        def split_string(s):\n            s = s.replace(\"“\", '\"').replace(\"”\", '\"').replace(\"'\", '''\"''')  # use english quotes\n            result = []\n            in_quotes = False\n            temp = \"\"\n\n            for idx,char in enumerate(s):\n                if char == '\"' and idx>155:\n                    temp += char\n                    if not in_quotes:\n                        result.append(temp)\n                        temp = \"\"\n\n                    in_quotes = not in_quotes\n                    continue\n                if in_quotes:\n                    if char.isspace():\n                        pass  # have space token\n\n                    result.append(\"“\" + char + \"”\")\n                else:\n                    temp += char\n\n            if temp:\n                result.append(temp)\n\n            return result\n\n        for idx, (txt, imgs) in enumerate(zip(text_list, ref_images)):\n\n            messages = [{\"role\": \"user\", \"content\": []}]\n\n            messages[0][\"content\"].append({\"type\": \"text\", \"text\": f\"{self.prefix}\"})\n\n            messages[0][\"content\"].append({\"type\": \"image\", \"image\": imgs})\n\n            # 再添加 text\n            messages[0][\"content\"].append({\"type\": \"text\", \"text\": f\"{txt}\"})\n\n            # Preparation for inference\n            text = self.processor.apply_chat_template(\n                messages, tokenize=False, add_generation_prompt=True, add_vision_id=True\n            )\n\n            image_inputs = [imgs]\n\n            inputs = self.processor(\n                text=[text],\n                images=image_inputs,\n              
  padding=True,\n                return_tensors=\"pt\",\n            )\n\n            old_inputs_ids = inputs.input_ids\n            text_split_list = split_string(text)\n\n            token_list = []\n            for text_each in text_split_list:\n                txt_inputs = self.processor(\n                    text=text_each,\n                    images=None,\n                    videos=None,\n                    padding=True,\n                    return_tensors=\"pt\",\n                )\n                token_each = txt_inputs.input_ids\n                if token_each[0][0] == 2073 and token_each[0][-1] == 854:\n                    token_each = token_each[:, 1:-1]\n                    token_list.append(token_each)\n                else:\n                    token_list.append(token_each)\n\n            new_txt_ids = torch.cat(token_list, dim=1).to(get_device_type())\n\n            new_txt_ids = new_txt_ids.to(old_inputs_ids.device)\n\n            idx1 = (old_inputs_ids == 151653).nonzero(as_tuple=True)[1][0]\n            idx2 = (new_txt_ids == 151653).nonzero(as_tuple=True)[1][0]\n            inputs.input_ids = (\n                torch.cat([old_inputs_ids[0, :idx1], new_txt_ids[0, idx2:]], dim=0)\n                .unsqueeze(0)\n                .to(get_device_type())\n            )\n            inputs.attention_mask = (inputs.input_ids > 0).long().to(get_device_type())\n            outputs = self.model_forward(\n                self.model,\n                input_ids=inputs.input_ids,\n                attention_mask=inputs.attention_mask,\n                pixel_values=inputs.pixel_values.to(get_device_type()),\n                image_grid_thw=inputs.image_grid_thw.to(get_device_type()),\n                output_hidden_states=True,\n            )\n\n            emb = outputs[-1]\n\n            embs[idx, : min(self.max_length, emb.shape[1] - 217)] = emb[0, 217:][\n                : self.max_length\n            ]\n\n            masks[idx, : min(self.max_length, 
emb.shape[1] - 217)] = torch.ones(\n                (min(self.max_length, emb.shape[1] - 217)),\n                dtype=torch.long,\n                device=get_torch_device().current_device(),\n            )\n\n        return embs, masks\n"
  },
  {
    "path": "diffsynth/models/wan_video_animate_adapter.py",
    "content": "import torch\nimport torch.nn as nn\nfrom torch.nn import functional as F\nimport math\nfrom typing import Tuple, Optional, List\nfrom einops import rearrange\n\n\n\nMEMORY_LAYOUT = {\n    \"flash\": (\n        lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),\n        lambda x: x,\n    ),\n    \"torch\": (\n        lambda x: x.transpose(1, 2),\n        lambda x: x.transpose(1, 2),\n    ),\n    \"vanilla\": (\n        lambda x: x.transpose(1, 2),\n        lambda x: x.transpose(1, 2),\n    ),\n}\n\n\ndef attention(\n    q,\n    k,\n    v,\n    mode=\"torch\",\n    drop_rate=0,\n    attn_mask=None,\n    causal=False,\n    max_seqlen_q=None,\n    batch_size=1,\n):\n    pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]\n\n    if mode == \"torch\":\n        if attn_mask is not None and attn_mask.dtype != torch.bool:\n            attn_mask = attn_mask.to(q.dtype)\n        x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal)\n\n    x = post_attn_layout(x)\n    b, s, a, d = x.shape\n    out = x.reshape(b, s, -1)\n    return out\n\n\nclass CausalConv1d(nn.Module):\n\n    def __init__(self, chan_in, chan_out, kernel_size=3, stride=1, dilation=1, pad_mode=\"replicate\", **kwargs):\n        super().__init__()\n\n        self.pad_mode = pad_mode\n        padding = (kernel_size - 1, 0)  # T\n        self.time_causal_padding = padding\n\n        self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)\n\n    def forward(self, x):\n        x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)\n        return self.conv(x)\n\n\n\nclass FaceEncoder(nn.Module):\n    def __init__(self, in_dim: int, hidden_dim: int, num_heads=int, dtype=None, device=None):\n        factory_kwargs = {\"dtype\": dtype, \"device\": device}\n        super().__init__()\n\n        self.num_heads = num_heads\n        self.conv1_local = CausalConv1d(in_dim, 1024 * num_heads, 3, 
stride=1)\n        self.norm1 = nn.LayerNorm(hidden_dim // 8, elementwise_affine=False, eps=1e-6, **factory_kwargs)\n        self.act = nn.SiLU()\n        self.conv2 = CausalConv1d(1024, 1024, 3, stride=2)\n        self.conv3 = CausalConv1d(1024, 1024, 3, stride=2)\n\n        self.out_proj = nn.Linear(1024, hidden_dim)\n        self.norm1 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)\n\n        self.norm2 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)\n\n        self.norm3 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)\n\n        self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim))\n\n    def forward(self, x):\n        \n        x = rearrange(x, \"b t c -> b c t\")\n        b, c, t = x.shape\n\n        x = self.conv1_local(x)\n        x = rearrange(x, \"b (n c) t -> (b n) t c\", n=self.num_heads)\n        \n        x = self.norm1(x)\n        x = self.act(x)\n        x = rearrange(x, \"b t c -> b c t\")\n        x = self.conv2(x)\n        x = rearrange(x, \"b c t -> b t c\")\n        x = self.norm2(x)\n        x = self.act(x)\n        x = rearrange(x, \"b t c -> b c t\")\n        x = self.conv3(x)\n        x = rearrange(x, \"b c t -> b t c\")\n        x = self.norm3(x)\n        x = self.act(x)\n        x = self.out_proj(x)\n        x = rearrange(x, \"(b n) t c -> b t n c\", b=b)\n        padding = self.padding_tokens.repeat(b, x.shape[1], 1, 1)\n        x = torch.cat([x, padding], dim=-2)\n        x_local = x.clone()\n\n        return x_local\n\n\n\nclass RMSNorm(nn.Module):\n    def __init__(\n        self,\n        dim: int,\n        elementwise_affine=True,\n        eps: float = 1e-6,\n        device=None,\n        dtype=None,\n    ):\n        \"\"\"\n        Initialize the RMSNorm normalization layer.\n\n        Args:\n            dim (int): The dimension of the input tensor.\n            eps (float, optional): A small value added to the denominator for numerical 
stability. Default is 1e-6.\n\n        Attributes:\n            eps (float): A small value added to the denominator for numerical stability.\n            weight (nn.Parameter): Learnable scaling parameter.\n\n        \"\"\"\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n        super().__init__()\n        self.eps = eps\n        if elementwise_affine:\n            self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))\n\n    def _norm(self, x):\n        \"\"\"\n        Apply the RMSNorm normalization to the input tensor.\n\n        Args:\n            x (torch.Tensor): The input tensor.\n\n        Returns:\n            torch.Tensor: The normalized tensor.\n\n        \"\"\"\n        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)\n\n    def forward(self, x):\n        \"\"\"\n        Forward pass through the RMSNorm layer.\n\n        Args:\n            x (torch.Tensor): The input tensor.\n\n        Returns:\n            torch.Tensor: The output tensor after applying RMSNorm.\n\n        \"\"\"\n        output = self._norm(x.float()).type_as(x)\n        if hasattr(self, \"weight\"):\n            output = output * self.weight\n        return output\n\n\ndef get_norm_layer(norm_layer):\n    \"\"\"\n    Get the normalization layer.\n\n    Args:\n        norm_layer (str): The type of normalization layer.\n\n    Returns:\n        norm_layer (nn.Module): The normalization layer.\n    \"\"\"\n    if norm_layer == \"layer\":\n        return nn.LayerNorm\n    elif norm_layer == \"rms\":\n        return RMSNorm\n    else:\n        raise NotImplementedError(f\"Norm layer {norm_layer} is not implemented\")\n\n\nclass FaceAdapter(nn.Module):\n    def __init__(\n        self,\n        hidden_dim: int,\n        heads_num: int,\n        qk_norm: bool = True,\n        qk_norm_type: str = \"rms\",\n        num_adapter_layers: int = 1,\n        dtype=None,\n        device=None,\n    ):\n\n        factory_kwargs = {\"dtype\": dtype, \"device\": 
device}\n        super().__init__()\n        self.hidden_size = hidden_dim\n        self.heads_num = heads_num\n        self.fuser_blocks = nn.ModuleList(\n            [\n                FaceBlock(\n                    self.hidden_size,\n                    self.heads_num,\n                    qk_norm=qk_norm,\n                    qk_norm_type=qk_norm_type,\n                    **factory_kwargs,\n                )\n                for _ in range(num_adapter_layers)\n            ]\n        )\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        motion_embed: torch.Tensor,\n        idx: int,\n        freqs_cis_q: Tuple[torch.Tensor, torch.Tensor] = None,\n        freqs_cis_k: Tuple[torch.Tensor, torch.Tensor] = None,\n    ) -> torch.Tensor:\n\n        return self.fuser_blocks[idx](x, motion_embed, freqs_cis_q, freqs_cis_k)\n\n\n\nclass FaceBlock(nn.Module):\n    def __init__(\n        self,\n        hidden_size: int,\n        heads_num: int,\n        qk_norm: bool = True,\n        qk_norm_type: str = \"rms\",\n        qk_scale: float = None,\n        dtype: Optional[torch.dtype] = None,\n        device: Optional[torch.device] = None,\n    ):\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n        super().__init__()\n\n        self.deterministic = False\n        self.hidden_size = hidden_size\n        self.heads_num = heads_num\n        head_dim = hidden_size // heads_num\n        self.scale = qk_scale or head_dim**-0.5\n       \n        self.linear1_kv = nn.Linear(hidden_size, hidden_size * 2, **factory_kwargs)\n        self.linear1_q = nn.Linear(hidden_size, hidden_size, **factory_kwargs)\n\n        self.linear2 = nn.Linear(hidden_size, hidden_size, **factory_kwargs)\n\n        qk_norm_layer = get_norm_layer(qk_norm_type)\n        self.q_norm = (\n            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()\n        )\n        self.k_norm = (\n            
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()\n        )\n\n        self.pre_norm_feat = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)\n\n        self.pre_norm_motion = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        motion_vec: torch.Tensor,\n        motion_mask: Optional[torch.Tensor] = None,\n        use_context_parallel=False,\n    ) -> torch.Tensor:\n        \n        B, T, N, C = motion_vec.shape\n        T_comp = T\n\n        x_motion = self.pre_norm_motion(motion_vec)\n        x_feat = self.pre_norm_feat(x)\n\n        kv = self.linear1_kv(x_motion)\n        q = self.linear1_q(x_feat)\n\n        k, v = rearrange(kv, \"B L N (K H D) -> K B L N H D\", K=2, H=self.heads_num)\n        q = rearrange(q, \"B S (H D) -> B S H D\", H=self.heads_num)\n\n        # Apply QK-Norm if needed.\n        q = self.q_norm(q).to(v)\n        k = self.k_norm(k).to(v)\n\n        k = rearrange(k, \"B L N H D -> (B L) H N D\")  \n        v = rearrange(v, \"B L N H D -> (B L) H N D\") \n\n        q = rearrange(q, \"B (L S) H D -> (B L) H S D\", L=T_comp)  \n        # Compute attention.\n        attn = F.scaled_dot_product_attention(q, k, v)\n\n        attn = rearrange(attn, \"(B L) H S D -> B (L S) (H D)\", L=T_comp)\n\n        output = self.linear2(attn)\n\n        if motion_mask is not None:\n            output = output * rearrange(motion_mask, \"B T H W -> B (T H W)\").unsqueeze(-1)\n\n        return output\n\n\n\ndef custom_qr(input_tensor):\n    original_dtype = input_tensor.dtype\n    if original_dtype == torch.bfloat16:\n        q, r = torch.linalg.qr(input_tensor.to(torch.float32))\n        return q.to(original_dtype), r.to(original_dtype)\n    return torch.linalg.qr(input_tensor)\n\ndef fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):\n\treturn F.leaky_relu(input + 
bias, negative_slope) * scale\n\n\ndef upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):\n\t_, minor, in_h, in_w = input.shape\n\tkernel_h, kernel_w = kernel.shape\n\n\tout = input.view(-1, minor, in_h, 1, in_w, 1)\n\tout = F.pad(out, [0, up_x - 1, 0, 0, 0, up_y - 1, 0, 0])\n\tout = out.view(-1, minor, in_h * up_y, in_w * up_x)\n\n\tout = F.pad(out, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])\n\tout = out[:, :, max(-pad_y0, 0): out.shape[2] - max(-pad_y1, 0),\n\t\t  max(-pad_x0, 0): out.shape[3] - max(-pad_x1, 0), ]\n\n\tout = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])\n\tw = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)\n\tout = F.conv2d(out, w)\n\tout = out.reshape(-1, minor, in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,\n\t\t\t\t\t  in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, )\n\treturn out[:, :, ::down_y, ::down_x]\n\n\ndef upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):\n\treturn upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])\n\n\ndef make_kernel(k):\n\tk = torch.tensor(k, dtype=torch.float32)\n\tif k.ndim == 1:\n\t\tk = k[None, :] * k[:, None]\n\tk /= k.sum()\n\treturn k\n\n\nclass FusedLeakyReLU(nn.Module):\n\tdef __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):\n\t\tsuper().__init__()\n\t\tself.bias = nn.Parameter(torch.zeros(1, channel, 1, 1))\n\t\tself.negative_slope = negative_slope\n\t\tself.scale = scale\n\n\tdef forward(self, input):\n\t\tout = fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)\n\t\treturn out\n\n\nclass Blur(nn.Module):\n\tdef __init__(self, kernel, pad, upsample_factor=1):\n\t\tsuper().__init__()\n\n\t\tkernel = make_kernel(kernel)\n\n\t\tif upsample_factor > 1:\n\t\t\tkernel = kernel * (upsample_factor ** 2)\n\n\t\tself.kernel = torch.nn.Parameter(kernel)\n\n\t\tself.pad = pad\n\n\tdef forward(self, input):\n\t\treturn upfirdn2d(input, 
self.kernel, pad=self.pad)\n\n\nclass ScaledLeakyReLU(nn.Module):\n\tdef __init__(self, negative_slope=0.2):\n\t\tsuper().__init__()\n\n\t\tself.negative_slope = negative_slope\n\n\tdef forward(self, input):\n\t\treturn F.leaky_relu(input, negative_slope=self.negative_slope)\n\n\nclass EqualConv2d(nn.Module):\n\tdef __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True):\n\t\tsuper().__init__()\n\n\t\tself.weight = nn.Parameter(torch.randn(out_channel, in_channel, kernel_size, kernel_size))\n\t\tself.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)\n\n\t\tself.stride = stride\n\t\tself.padding = padding\n\n\t\tif bias:\n\t\t\tself.bias = nn.Parameter(torch.zeros(out_channel))\n\t\telse:\n\t\t\tself.bias = None\n\n\tdef forward(self, input):\n\n\t\treturn F.conv2d(input, self.weight * self.scale, bias=self.bias, stride=self.stride, padding=self.padding)\n\n\tdef __repr__(self):\n\t\treturn (\n\t\t\tf'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'\n\t\t\tf' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'\n\t\t)\n\n\nclass EqualLinear(nn.Module):\n\tdef __init__(self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None):\n\t\tsuper().__init__()\n\n\t\tself.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))\n\n\t\tif bias:\n\t\t\tself.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))\n\t\telse:\n\t\t\tself.bias = None\n\n\t\tself.activation = activation\n\n\t\tself.scale = (1 / math.sqrt(in_dim)) * lr_mul\n\t\tself.lr_mul = lr_mul\n\n\tdef forward(self, input):\n\n\t\tif self.activation:\n\t\t\tout = F.linear(input, self.weight * self.scale)\n\t\t\tout = fused_leaky_relu(out, self.bias * self.lr_mul)\n\t\telse:\n\t\t\tout = F.linear(input, self.weight * self.scale, bias=self.bias * self.lr_mul)\n\n\t\treturn out\n\n\tdef __repr__(self):\n\t\treturn (f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})')\n\n\nclass 
ConvLayer(nn.Sequential):\n\tdef __init__(\n\t\t\tself,\n\t\t\tin_channel,\n\t\t\tout_channel,\n\t\t\tkernel_size,\n\t\t\tdownsample=False,\n\t\t\tblur_kernel=[1, 3, 3, 1],\n\t\t\tbias=True,\n\t\t\tactivate=True,\n\t):\n\t\tlayers = []\n\n\t\tif downsample:\n\t\t\tfactor = 2\n\t\t\tp = (len(blur_kernel) - factor) + (kernel_size - 1)\n\t\t\tpad0 = (p + 1) // 2\n\t\t\tpad1 = p // 2\n\n\t\t\tlayers.append(Blur(blur_kernel, pad=(pad0, pad1)))\n\n\t\t\tstride = 2\n\t\t\tself.padding = 0\n\n\t\telse:\n\t\t\tstride = 1\n\t\t\tself.padding = kernel_size // 2\n\n\t\tlayers.append(EqualConv2d(in_channel, out_channel, kernel_size, padding=self.padding, stride=stride,\n\t\t\t\t\t\t\t\t  bias=bias and not activate))\n\n\t\tif activate:\n\t\t\tif bias:\n\t\t\t\tlayers.append(FusedLeakyReLU(out_channel))\n\t\t\telse:\n\t\t\t\tlayers.append(ScaledLeakyReLU(0.2))\n\n\t\tsuper().__init__(*layers)\n\n\nclass ResBlock(nn.Module):\n\tdef __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):\n\t\tsuper().__init__()\n\n\t\tself.conv1 = ConvLayer(in_channel, in_channel, 3)\n\t\tself.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)\n\n\t\tself.skip = ConvLayer(in_channel, out_channel, 1, downsample=True, activate=False, bias=False)\n\n\tdef forward(self, input):\n\t\tout = self.conv1(input)\n\t\tout = self.conv2(out)\n\n\t\tskip = self.skip(input)\n\t\tout = (out + skip) / math.sqrt(2)\n\n\t\treturn out\n\n\nclass EncoderApp(nn.Module):\n\tdef __init__(self, size, w_dim=512):\n\t\tsuper(EncoderApp, self).__init__()\n\n\t\tchannels = {\n\t\t\t4: 512,\n\t\t\t8: 512,\n\t\t\t16: 512,\n\t\t\t32: 512,\n\t\t\t64: 256,\n\t\t\t128: 128,\n\t\t\t256: 64,\n\t\t\t512: 32,\n\t\t\t1024: 16\n\t\t}\n\n\t\tself.w_dim = w_dim\n\t\tlog_size = int(math.log(size, 2))\n\n\t\tself.convs = nn.ModuleList()\n\t\tself.convs.append(ConvLayer(3, channels[size], 1))\n\n\t\tin_channel = channels[size]\n\t\tfor i in range(log_size, 2, -1):\n\t\t\tout_channel = channels[2 ** (i - 
1)]\n\t\t\tself.convs.append(ResBlock(in_channel, out_channel))\n\t\t\tin_channel = out_channel\n\n\t\tself.convs.append(EqualConv2d(in_channel, self.w_dim, 4, padding=0, bias=False))\n\n\tdef forward(self, x):\n\n\t\tres = []\n\t\th = x\n\t\tfor conv in self.convs:\n\t\t\th = conv(h)\n\t\t\tres.append(h)\n\n\t\treturn res[-1].squeeze(-1).squeeze(-1), res[::-1][2:]\n\n\nclass Encoder(nn.Module):\n\tdef __init__(self, size, dim=512, dim_motion=20):\n\t\tsuper(Encoder, self).__init__()\n\n\t\t# appearance netmork\n\t\tself.net_app = EncoderApp(size, dim)\n\n\t\t# motion network\n\t\tfc = [EqualLinear(dim, dim)]\n\t\tfor i in range(3):\n\t\t\tfc.append(EqualLinear(dim, dim))\n\n\t\tfc.append(EqualLinear(dim, dim_motion))\n\t\tself.fc = nn.Sequential(*fc)\n\n\tdef enc_app(self, x):\n\t\th_source = self.net_app(x)\n\t\treturn h_source\n\n\tdef enc_motion(self, x):\n\t\th, _ = self.net_app(x)\n\t\th_motion = self.fc(h)\n\t\treturn h_motion\n\n\nclass Direction(nn.Module):\n    def __init__(self, motion_dim):\n        super(Direction, self).__init__()\n        self.weight = nn.Parameter(torch.randn(512, motion_dim))\n\n    def forward(self, input):\n\n        weight = self.weight + 1e-8\n        Q, R = custom_qr(weight)\n        if input is None:\n            return Q\n        else:\n            input_diag = torch.diag_embed(input)  # alpha, diagonal matrix\n            out = torch.matmul(input_diag, Q.T)\n            out = torch.sum(out, dim=1)\n            return out\n\n\nclass Synthesis(nn.Module):\n    def __init__(self, motion_dim):\n        super(Synthesis, self).__init__()\n        self.direction = Direction(motion_dim)\n\n\nclass Generator(nn.Module):\n    def __init__(self, size, style_dim=512, motion_dim=20):\n        super().__init__()\n\n        self.enc = Encoder(size, style_dim, motion_dim)\n        self.dec = Synthesis(motion_dim)\n\n    def get_motion(self, img):\n        #motion_feat = self.enc.enc_motion(img)\n        motion_feat = 
torch.utils.checkpoint.checkpoint((self.enc.enc_motion), img, use_reentrant=True)\n        motion = self.dec.direction(motion_feat)\n        return motion\n\n\nclass WanAnimateAdapter(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.pose_patch_embedding = torch.nn.Conv3d(16, 5120, kernel_size=(1, 2, 2), stride=(1, 2, 2))\n        self.motion_encoder = Generator(size=512, style_dim=512, motion_dim=20)\n        self.face_adapter = FaceAdapter(heads_num=40, hidden_dim=5120, num_adapter_layers=40 // 5)\n        self.face_encoder = FaceEncoder(in_dim=512, hidden_dim=5120, num_heads=4)\n    \n    def after_patch_embedding(self, x: List[torch.Tensor], pose_latents, face_pixel_values):\n        pose_latents = self.pose_patch_embedding(pose_latents)\n        x[:, :, 1:] += pose_latents\n        \n        b,c,T,h,w = face_pixel_values.shape\n        face_pixel_values = rearrange(face_pixel_values, \"b c t h w -> (b t) c h w\")\n\n        encode_bs = 8\n        face_pixel_values_tmp = []\n        for i in range(math.ceil(face_pixel_values.shape[0]/encode_bs)):\n            face_pixel_values_tmp.append(self.motion_encoder.get_motion(face_pixel_values[i*encode_bs:(i+1)*encode_bs]))\n\n        motion_vec = torch.cat(face_pixel_values_tmp)\n        \n        motion_vec = rearrange(motion_vec, \"(b t) c -> b t c\", t=T)\n        motion_vec = self.face_encoder(motion_vec)\n\n        B, L, H, C = motion_vec.shape\n        pad_face = torch.zeros(B, 1, H, C).type_as(motion_vec)\n        motion_vec = torch.cat([pad_face, motion_vec], dim=1)\n        return x, motion_vec\n    \n    def after_transformer_block(self, block_idx, x, motion_vec, motion_masks=None):\n        if block_idx % 5 == 0:\n            adapter_args = [x, motion_vec, motion_masks, False]\n            residual_out = self.face_adapter.fuser_blocks[block_idx // 5](*adapter_args)\n            x = residual_out + x\n        return x\n"
  },
  {
    "path": "diffsynth/models/wan_video_camera_controller.py",
    "content": "import torch\nimport torch.nn as nn\nimport numpy as np\nfrom einops import rearrange\nimport os\nfrom typing_extensions import Literal\n\nclass SimpleAdapter(nn.Module):\n    def __init__(self, in_dim, out_dim, kernel_size, stride, num_residual_blocks=1):\n        super(SimpleAdapter, self).__init__()\n\n        # Pixel Unshuffle: reduce spatial dimensions by a factor of 8\n        self.pixel_unshuffle = nn.PixelUnshuffle(downscale_factor=8)\n\n        # Convolution: reduce spatial dimensions by a factor\n        #  of 2 (without overlap)\n        self.conv = nn.Conv2d(in_dim * 64, out_dim, kernel_size=kernel_size, stride=stride, padding=0)\n\n        # Residual blocks for feature extraction\n        self.residual_blocks = nn.Sequential(\n            *[ResidualBlock(out_dim) for _ in range(num_residual_blocks)]\n        )\n\n    def forward(self, x):\n        # Reshape to merge the frame dimension into batch\n        bs, c, f, h, w = x.size()\n        x = x.permute(0, 2, 1, 3, 4).contiguous().view(bs * f, c, h, w)\n\n        # Pixel Unshuffle operation\n        x_unshuffled = self.pixel_unshuffle(x)\n\n        # Convolution operation\n        x_conv = self.conv(x_unshuffled)\n\n        # Feature extraction with residual blocks\n        out = self.residual_blocks(x_conv)\n\n        # Reshape to restore original bf dimension\n        out = out.view(bs, f, out.size(1), out.size(2), out.size(3))\n\n        # Permute dimensions to reorder (if needed), e.g., swap channels and feature frames\n        out = out.permute(0, 2, 1, 3, 4)\n\n        return out\n    \n    def process_camera_coordinates(\n        self,\n        direction: Literal[\"Left\", \"Right\", \"Up\", \"Down\", \"LeftUp\", \"LeftDown\", \"RightUp\", \"RightDown\"],\n        length: int,\n        height: int,\n        width: int,\n        speed: float = 1/54,\n        origin=(0, 0.532139961, 0.946026558, 0.5, 0.5, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0)\n    ):\n        if origin is 
None:\n            origin = (0, 0.532139961, 0.946026558, 0.5, 0.5, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0)\n        coordinates = generate_camera_coordinates(direction, length, speed, origin)\n        plucker_embedding = process_pose_file(coordinates, width, height)\n        return plucker_embedding\n        \n    \n\nclass ResidualBlock(nn.Module):\n    def __init__(self, dim):\n        super(ResidualBlock, self).__init__()\n        self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, padding=1)\n        self.relu = nn.ReLU(inplace=True)\n        self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, padding=1)\n\n    def forward(self, x):\n        residual = x\n        out = self.relu(self.conv1(x))\n        out = self.conv2(out)\n        out += residual\n        return out\n    \nclass Camera(object):\n    \"\"\"Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py\n    \"\"\"\n    def __init__(self, entry):\n        fx, fy, cx, cy = entry[1:5]\n        self.fx = fx\n        self.fy = fy\n        self.cx = cx\n        self.cy = cy\n        w2c_mat = np.array(entry[7:]).reshape(3, 4)\n        w2c_mat_4x4 = np.eye(4)\n        w2c_mat_4x4[:3, :] = w2c_mat\n        self.w2c_mat = w2c_mat_4x4\n        self.c2w_mat = np.linalg.inv(w2c_mat_4x4)\n\ndef get_relative_pose(cam_params):\n    \"\"\"Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py\n    \"\"\"\n    abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params]\n    abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params]\n    cam_to_origin = 0\n    target_cam_c2w = np.array([\n        [1, 0, 0, 0],\n        [0, 1, 0, -cam_to_origin],\n        [0, 0, 1, 0],\n        [0, 0, 0, 1]\n    ])\n    abs2rel = target_cam_c2w @ abs_w2cs[0]\n    ret_poses = [target_cam_c2w, ] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]]\n    ret_poses = np.array(ret_poses, dtype=np.float32)\n    return ret_poses\n\ndef custom_meshgrid(*args):\n    # torch>=2.0.0 only\n    return torch.meshgrid(*args, 
indexing='ij')\n\n\ndef ray_condition(K, c2w, H, W, device):\n    \"\"\"Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py\n    \"\"\"\n    # c2w: B, V, 4, 4\n    # K: B, V, 4\n\n    B = K.shape[0]\n\n    j, i = custom_meshgrid(\n        torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),\n        torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype),\n    )\n    i = i.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5  # [B, HxW]\n    j = j.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5  # [B, HxW]\n\n    fx, fy, cx, cy = K.chunk(4, dim=-1)  # B,V, 1\n\n    zs = torch.ones_like(i)  # [B, HxW]\n    xs = (i - cx) / fx * zs\n    ys = (j - cy) / fy * zs\n    zs = zs.expand_as(ys)\n\n    directions = torch.stack((xs, ys, zs), dim=-1)  # B, V, HW, 3\n    directions = directions / directions.norm(dim=-1, keepdim=True)  # B, V, HW, 3\n\n    rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2)  # B, V, 3, HW\n    rays_o = c2w[..., :3, 3]  # B, V, 3\n    rays_o = rays_o[:, :, None].expand_as(rays_d)  # B, V, 3, HW\n    # c2w @ dirctions\n    rays_dxo = torch.linalg.cross(rays_o, rays_d)\n    plucker = torch.cat([rays_dxo, rays_d], dim=-1)\n    plucker = plucker.reshape(B, c2w.shape[1], H, W, 6)  # B, V, H, W, 6\n    # plucker = plucker.permute(0, 1, 4, 2, 3)\n    return plucker\n\n\ndef process_pose_file(cam_params, width=672, height=384, original_pose_width=1280, original_pose_height=720, device='cpu', return_poses=False):\n    if return_poses:\n        return cam_params\n    else:\n        cam_params = [Camera(cam_param) for cam_param in cam_params]\n\n        sample_wh_ratio = width / height\n        pose_wh_ratio = original_pose_width / original_pose_height  # Assuming placeholder ratios, change as needed\n\n        if pose_wh_ratio > sample_wh_ratio:\n            resized_ori_w = height * pose_wh_ratio\n            for cam_param in cam_params:\n                cam_param.fx = resized_ori_w * cam_param.fx / width\n        
else:\n            resized_ori_h = width / pose_wh_ratio\n            for cam_param in cam_params:\n                cam_param.fy = resized_ori_h * cam_param.fy / height\n\n        intrinsic = np.asarray([[cam_param.fx * width,\n                                cam_param.fy * height,\n                                cam_param.cx * width,\n                                cam_param.cy * height]\n                                for cam_param in cam_params], dtype=np.float32)\n\n        K = torch.as_tensor(intrinsic)[None]  # [1, 1, 4]\n        c2ws = get_relative_pose(cam_params)  # Assuming this function is defined elsewhere\n        c2ws = torch.as_tensor(c2ws)[None]  # [1, n_frame, 4, 4]\n        plucker_embedding = ray_condition(K, c2ws, height, width, device=device)[0].permute(0, 3, 1, 2).contiguous()  # V, 6, H, W\n        plucker_embedding = plucker_embedding[None]\n        plucker_embedding = rearrange(plucker_embedding, \"b f c h w -> b f h w c\")[0]\n        return plucker_embedding\n\n\n\ndef generate_camera_coordinates(\n    direction: Literal[\"Left\", \"Right\", \"Up\", \"Down\", \"LeftUp\", \"LeftDown\", \"RightUp\", \"RightDown\", \"In\", \"Out\"],\n    length: int,\n    speed: float = 1/54,\n    origin=(0, 0.532139961, 0.946026558, 0.5, 0.5, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0)\n):\n    coordinates = [list(origin)]\n    while len(coordinates) < length:\n        coor = coordinates[-1].copy()\n        if \"Left\" in direction:\n            coor[9] += speed\n        if \"Right\" in direction:\n            coor[9] -= speed\n        if \"Up\" in direction:\n            coor[13] += speed\n        if \"Down\" in direction:\n            coor[13] -= speed\n        if \"In\" in direction:\n            coor[18] -= speed\n        if \"Out\" in direction:\n            coor[18] += speed\n        coordinates.append(coor)\n    return coordinates\n"
  },
  {
    "path": "diffsynth/models/wan_video_dit.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport math\nfrom typing import Tuple, Optional\nfrom einops import rearrange\nfrom .wan_video_camera_controller import SimpleAdapter\nfrom ..core.gradient import gradient_checkpoint_forward\nfrom .wantodance import WanToDanceRotaryEmbedding, WanToDanceMusicEncoderLayer\n\ntry:\n    import flash_attn_interface\n    FLASH_ATTN_3_AVAILABLE = True\nexcept ModuleNotFoundError:\n    FLASH_ATTN_3_AVAILABLE = False\n\ntry:\n    import flash_attn\n    FLASH_ATTN_2_AVAILABLE = True\nexcept ModuleNotFoundError:\n    FLASH_ATTN_2_AVAILABLE = False\n\ntry:\n    from sageattention import sageattn\n    SAGE_ATTN_AVAILABLE = True\nexcept ModuleNotFoundError:\n    SAGE_ATTN_AVAILABLE = False\n    \n    \ndef flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, compatibility_mode=False):\n    if compatibility_mode:\n        q = rearrange(q, \"b s (n d) -> b n s d\", n=num_heads)\n        k = rearrange(k, \"b s (n d) -> b n s d\", n=num_heads)\n        v = rearrange(v, \"b s (n d) -> b n s d\", n=num_heads)\n        x = F.scaled_dot_product_attention(q, k, v)\n        x = rearrange(x, \"b n s d -> b s (n d)\", n=num_heads)\n    elif FLASH_ATTN_3_AVAILABLE:\n        q = rearrange(q, \"b s (n d) -> b s n d\", n=num_heads)\n        k = rearrange(k, \"b s (n d) -> b s n d\", n=num_heads)\n        v = rearrange(v, \"b s (n d) -> b s n d\", n=num_heads)\n        x = flash_attn_interface.flash_attn_func(q, k, v)\n        if isinstance(x,tuple):\n            x = x[0]\n        x = rearrange(x, \"b s n d -> b s (n d)\", n=num_heads)\n    elif FLASH_ATTN_2_AVAILABLE:\n        q = rearrange(q, \"b s (n d) -> b s n d\", n=num_heads)\n        k = rearrange(k, \"b s (n d) -> b s n d\", n=num_heads)\n        v = rearrange(v, \"b s (n d) -> b s n d\", n=num_heads)\n        x = flash_attn.flash_attn_func(q, k, v)\n        x = rearrange(x, \"b s n d -> b s (n d)\", n=num_heads)\n    elif 
SAGE_ATTN_AVAILABLE:\n        q = rearrange(q, \"b s (n d) -> b n s d\", n=num_heads)\n        k = rearrange(k, \"b s (n d) -> b n s d\", n=num_heads)\n        v = rearrange(v, \"b s (n d) -> b n s d\", n=num_heads)\n        x = sageattn(q, k, v)\n        x = rearrange(x, \"b n s d -> b s (n d)\", n=num_heads)\n    else:\n        q = rearrange(q, \"b s (n d) -> b n s d\", n=num_heads)\n        k = rearrange(k, \"b s (n d) -> b n s d\", n=num_heads)\n        v = rearrange(v, \"b s (n d) -> b n s d\", n=num_heads)\n        x = F.scaled_dot_product_attention(q, k, v)\n        x = rearrange(x, \"b n s d -> b s (n d)\", n=num_heads)\n    return x\n\n\ndef modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):\n    return (x * (1 + scale) + shift)\n\n\ndef sinusoidal_embedding_1d(dim, position):\n    sinusoid = torch.outer(position.type(torch.float64), torch.pow(\n        10000, -torch.arange(dim//2, dtype=torch.float64, device=position.device).div(dim//2)))\n    x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)\n    return x.to(position.dtype)\n\n\ndef precompute_freqs_cis_3d(dim: int, end: int = 1024, theta: float = 10000.0):\n    # 3d rope precompute\n    f_freqs_cis = precompute_freqs_cis(dim - 2 * (dim // 3), end, theta)\n    h_freqs_cis = precompute_freqs_cis(dim // 3, end, theta)\n    w_freqs_cis = precompute_freqs_cis(dim // 3, end, theta)\n    return f_freqs_cis, h_freqs_cis, w_freqs_cis\n\n\ndef precompute_freqs_cis(dim: int, end: int = 1024, theta: float = 10000.0):\n    # 1d rope precompute\n    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)\n                   [: (dim // 2)].double() / dim))\n    freqs = torch.outer(torch.arange(end, device=freqs.device), freqs)\n    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64\n    return freqs_cis\n\n\ndef rope_apply(x, freqs, num_heads):\n    x = rearrange(x, \"b s (n d) -> b s n d\", n=num_heads)\n    x_out = torch.view_as_complex(x.to(torch.float64).reshape(\n        
x.shape[0], x.shape[1], x.shape[2], -1, 2))\n    freqs = freqs.to(torch.complex64) if freqs.device.type == \"npu\" else freqs\n    x_out = torch.view_as_real(x_out * freqs).flatten(2)\n    return x_out.to(x.dtype)\n\n\ndef set_to_torch_norm(models):\n    for model in models:\n        for module in model.modules():\n            if isinstance(module, RMSNorm):\n                module.use_torch_norm = True\n\n\nclass RMSNorm(nn.Module):\n    def __init__(self, dim, eps=1e-5):\n        super().__init__()\n        self.eps = eps\n        self.weight = nn.Parameter(torch.ones(dim))\n        self.use_torch_norm = False\n        self.normalized_shape = (dim,)\n\n    def norm(self, x):\n        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)\n\n    def forward(self, x):\n        dtype = x.dtype\n        if self.use_torch_norm:\n            return F.rms_norm(x, self.normalized_shape, self.weight, self.eps)\n        else:        \n            return self.norm(x.float()).to(dtype) * self.weight\n\n\nclass AttentionModule(nn.Module):\n    def __init__(self, num_heads):\n        super().__init__()\n        self.num_heads = num_heads\n        \n    def forward(self, q, k, v):\n        x = flash_attention(q=q, k=k, v=v, num_heads=self.num_heads)\n        return x\n\n\nclass SelfAttention(nn.Module):\n    def __init__(self, dim: int, num_heads: int, eps: float = 1e-6):\n        super().__init__()\n        self.dim = dim\n        self.num_heads = num_heads\n        self.head_dim = dim // num_heads\n\n        self.q = nn.Linear(dim, dim)\n        self.k = nn.Linear(dim, dim)\n        self.v = nn.Linear(dim, dim)\n        self.o = nn.Linear(dim, dim)\n        self.norm_q = RMSNorm(dim, eps=eps)\n        self.norm_k = RMSNorm(dim, eps=eps)\n        \n        self.attn = AttentionModule(self.num_heads)\n\n    def forward(self, x, freqs):\n        q = self.norm_q(self.q(x))\n        k = self.norm_k(self.k(x))\n        v = self.v(x)\n        q = rope_apply(q, freqs, 
self.num_heads)\n        k = rope_apply(k, freqs, self.num_heads)\n        x = self.attn(q, k, v)\n        return self.o(x)\n\n\nclass CrossAttention(nn.Module):\n    def __init__(self, dim: int, num_heads: int, eps: float = 1e-6, has_image_input: bool = False):\n        super().__init__()\n        self.dim = dim\n        self.num_heads = num_heads\n        self.head_dim = dim // num_heads\n\n        self.q = nn.Linear(dim, dim)\n        self.k = nn.Linear(dim, dim)\n        self.v = nn.Linear(dim, dim)\n        self.o = nn.Linear(dim, dim)\n        self.norm_q = RMSNorm(dim, eps=eps)\n        self.norm_k = RMSNorm(dim, eps=eps)\n        self.has_image_input = has_image_input\n        if has_image_input:\n            self.k_img = nn.Linear(dim, dim)\n            self.v_img = nn.Linear(dim, dim)\n            self.norm_k_img = RMSNorm(dim, eps=eps)\n            \n        self.attn = AttentionModule(self.num_heads)\n\n    def forward(self, x: torch.Tensor, y: torch.Tensor):\n        if self.has_image_input:\n            img = y[:, :257]\n            ctx = y[:, 257:]\n        else:\n            ctx = y\n        q = self.norm_q(self.q(x))\n        k = self.norm_k(self.k(ctx))\n        v = self.v(ctx)\n        x = self.attn(q, k, v)\n        if self.has_image_input:\n            k_img = self.norm_k_img(self.k_img(img))\n            v_img = self.v_img(img)\n            y = flash_attention(q, k_img, v_img, num_heads=self.num_heads)\n            x = x + y\n        return self.o(x)\n\n\nclass GateModule(nn.Module):\n    def __init__(self,):\n        super().__init__()\n\n    def forward(self, x, gate, residual):\n        return x + gate * residual\n\nclass DiTBlock(nn.Module):\n    def __init__(self, has_image_input: bool, dim: int, num_heads: int, ffn_dim: int, eps: float = 1e-6):\n        super().__init__()\n        self.dim = dim\n        self.num_heads = num_heads\n        self.ffn_dim = ffn_dim\n\n        self.self_attn = SelfAttention(dim, num_heads, eps)\n        
self.cross_attn = CrossAttention(\n            dim, num_heads, eps, has_image_input=has_image_input)\n        self.norm1 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)\n        self.norm2 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)\n        self.norm3 = nn.LayerNorm(dim, eps=eps)\n        self.ffn = nn.Sequential(nn.Linear(dim, ffn_dim), nn.GELU(\n            approximate='tanh'), nn.Linear(ffn_dim, dim))\n        self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)\n        self.gate = GateModule()\n\n    def forward(self, x, context, t_mod, freqs):\n        has_seq = len(t_mod.shape) == 4\n        chunk_dim = 2 if has_seq else 1\n        # msa: multi-head self-attention  mlp: multi-layer perceptron\n        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (\n            self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=chunk_dim)\n        if has_seq:\n            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (\n                shift_msa.squeeze(2), scale_msa.squeeze(2), gate_msa.squeeze(2),\n                shift_mlp.squeeze(2), scale_mlp.squeeze(2), gate_mlp.squeeze(2),\n            )\n        input_x = modulate(self.norm1(x), shift_msa, scale_msa)\n        x = self.gate(x, gate_msa, self.self_attn(input_x, freqs))\n        x = x + self.cross_attn(self.norm3(x), context)\n        input_x = modulate(self.norm2(x), shift_mlp, scale_mlp)\n        x = self.gate(x, gate_mlp, self.ffn(input_x))\n        return x\n\n\nclass MLP(torch.nn.Module):\n    def __init__(self, in_dim, out_dim, has_pos_emb=False):\n        super().__init__()\n        self.proj = torch.nn.Sequential(\n            nn.LayerNorm(in_dim),\n            nn.Linear(in_dim, in_dim),\n            nn.GELU(),\n            nn.Linear(in_dim, out_dim),\n            nn.LayerNorm(out_dim)\n        )\n        self.has_pos_emb = has_pos_emb\n        if has_pos_emb:\n            self.emb_pos = 
torch.nn.Parameter(torch.zeros((1, 514, 1280)))\n\n    def forward(self, x):\n        if self.has_pos_emb:\n            x = x + self.emb_pos.to(dtype=x.dtype, device=x.device)\n        return self.proj(x)\n\n\nclass Head(nn.Module):\n    def __init__(self, dim: int, out_dim: int, patch_size: Tuple[int, int, int], eps: float):\n        super().__init__()\n        self.dim = dim\n        self.patch_size = patch_size\n        self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)\n        self.head = nn.Linear(dim, out_dim * math.prod(patch_size))\n        self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)\n\n    def forward(self, x, t_mod):\n        if len(t_mod.shape) == 3:\n            shift, scale = (self.modulation.unsqueeze(0).to(dtype=t_mod.dtype, device=t_mod.device) + t_mod.unsqueeze(2)).chunk(2, dim=2)\n            x = (self.head(self.norm(x) * (1 + scale.squeeze(2)) + shift.squeeze(2)))\n        else:\n            shift, scale = (self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(2, dim=1)\n            x = (self.head(self.norm(x) * (1 + scale) + shift))\n        return x\n\n\ndef wantodance_torch_dfs(model: nn.Module, parent_name='root'):\n    module_names, modules = [], []\n    current_name = parent_name if parent_name else 'root'\n    module_names.append(current_name)\n    modules.append(model)\n    for name, child in model.named_children():\n        if parent_name:\n            child_name = f'{parent_name}.{name}'\n        else:\n            child_name = name\n        child_modules, child_names = wantodance_torch_dfs(child, child_name)\n        module_names += child_names\n        modules += child_modules\n    return modules, module_names\n\n\nclass WanToDanceInjector(nn.Module):\n    def __init__(self, all_modules, all_modules_names, dim=2048, num_heads=32, inject_layer=[0, 27]):\n        super().__init__()\n        self.injected_block_id = {}\n        injector_id = 0\n        for mod_name, mod in 
zip(all_modules_names, all_modules):\n            if isinstance(mod, DiTBlock):\n                for inject_id in inject_layer:\n                    if f'root.transformer_blocks.{inject_id}' == mod_name:\n                        self.injected_block_id[inject_id] = injector_id\n                        injector_id += 1\n\n        self.injector = nn.ModuleList(\n            [\n                CrossAttention(\n                    dim=dim,\n                    num_heads=num_heads,\n                )\n                for _ in range(injector_id)\n            ]\n        )\n        self.injector_pre_norm_feat = nn.ModuleList(\n            [\n                nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6,)\n                for _ in range(injector_id)\n            ]\n        )\n        self.injector_pre_norm_vec = nn.ModuleList(\n            [\n                nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6,)\n                for _ in range(injector_id)\n            ]\n        )\n\n\nclass WanModel(torch.nn.Module):\n    def __init__(\n        self,\n        dim: int,\n        in_dim: int,\n        ffn_dim: int,\n        out_dim: int,\n        text_dim: int,\n        freq_dim: int,\n        eps: float,\n        patch_size: Tuple[int, int, int],\n        num_heads: int,\n        num_layers: int,\n        has_image_input: bool,\n        has_image_pos_emb: bool = False,\n        has_ref_conv: bool = False,\n        add_control_adapter: bool = False,\n        in_dim_control_adapter: int = 24,\n        seperated_timestep: bool = False,\n        require_vae_embedding: bool = True,\n        require_clip_embedding: bool = True,\n        fuse_vae_embedding_in_latents: bool = False,\n        wantodance_enable_music_inject: bool = False,\n        wantodance_music_inject_layers = [0, 4, 8, 12, 16, 20, 24, 27],\n        wantodance_enable_refimage: bool = False,\n        wantodance_enable_refface: bool = False,\n        wantodance_enable_global: bool = False,\n        
wantodance_enable_dynamicfps: bool = False,\n        wantodance_enable_unimodel: bool = False,\n    ):\n        super().__init__()\n        self.dim = dim\n        self.in_dim = in_dim\n        self.freq_dim = freq_dim\n        self.has_image_input = has_image_input\n        self.patch_size = patch_size\n        self.seperated_timestep = seperated_timestep\n        self.require_vae_embedding = require_vae_embedding\n        self.require_clip_embedding = require_clip_embedding\n        self.fuse_vae_embedding_in_latents = fuse_vae_embedding_in_latents\n\n        self.patch_embedding = nn.Conv3d(\n            in_dim, dim, kernel_size=patch_size, stride=patch_size)\n        self.text_embedding = nn.Sequential(\n            nn.Linear(text_dim, dim),\n            nn.GELU(approximate='tanh'),\n            nn.Linear(dim, dim)\n        )\n        self.time_embedding = nn.Sequential(\n            nn.Linear(freq_dim, dim),\n            nn.SiLU(),\n            nn.Linear(dim, dim)\n        )\n        self.time_projection = nn.Sequential(\n            nn.SiLU(), nn.Linear(dim, dim * 6))\n        self.blocks = nn.ModuleList([\n            DiTBlock(has_image_input, dim, num_heads, ffn_dim, eps)\n            for _ in range(num_layers)\n        ])\n        self.head = Head(dim, out_dim, patch_size, eps)\n        head_dim = dim // num_heads\n\n        if wantodance_enable_dynamicfps or wantodance_enable_unimodel:\n            end = int(22350 / 8 + 0.5) # 149f * 30fps * 5s = 22350\n            self.freqs = precompute_freqs_cis_3d(head_dim, end=end)\n        else:\n            self.freqs = precompute_freqs_cis_3d(head_dim)\n\n        if has_image_input:\n            self.img_emb = MLP(1280, dim, has_pos_emb=has_image_pos_emb)  # clip_feature_dim = 1280\n        if has_ref_conv:\n            self.ref_conv = nn.Conv2d(16, dim, kernel_size=(2, 2), stride=(2, 2))\n        self.has_image_pos_emb = has_image_pos_emb\n        self.has_ref_conv = has_ref_conv\n        if 
add_control_adapter:\n            self.control_adapter = SimpleAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:])\n        else:\n            self.control_adapter = None\n\n        self.prepare_wantodance(in_dim, dim, num_heads, has_image_pos_emb, out_dim, patch_size, eps,\n                                wantodance_enable_music_inject, wantodance_music_inject_layers, wantodance_enable_refimage, wantodance_enable_refface,\n                                wantodance_enable_global, wantodance_enable_dynamicfps, wantodance_enable_unimodel)\n\n    def prepare_wantodance(\n        self,\n        in_dim, dim, num_heads, has_image_pos_emb, out_dim, patch_size, eps,\n        wantodance_enable_music_inject: bool = False,\n        wantodance_music_inject_layers = [0, 4, 8, 12, 16, 20, 24, 27],\n        wantodance_enable_refimage: bool = False,\n        wantodance_enable_refface: bool = False,\n        wantodance_enable_global: bool = False,\n        wantodance_enable_dynamicfps: bool = False,\n        wantodance_enable_unimodel: bool = False,\n    ):\n        if wantodance_enable_music_inject:\n            all_modules, all_modules_names = wantodance_torch_dfs(self.blocks, parent_name=\"root.transformer_blocks\")\n            self.music_injector = WanToDanceInjector(all_modules, all_modules_names, dim=dim, num_heads=num_heads, inject_layer=wantodance_music_inject_layers)\n        if wantodance_enable_refimage:\n            self.img_emb_refimage = MLP(1280, dim, has_pos_emb=has_image_pos_emb)  # clip_feature_dim = 1280\n        if wantodance_enable_refface:\n            self.img_emb_refface = MLP(1280, dim, has_pos_emb=has_image_pos_emb)  # clip_feature_dim = 1280\n        if wantodance_enable_global or wantodance_enable_dynamicfps or wantodance_enable_unimodel:\n            music_feature_dim = 35\n            ff_size = 1024\n            dropout = 0.1\n            latent_dim = 256\n            nhead = 4\n            activation = F.gelu\n    
        rotary = WanToDanceRotaryEmbedding(dim=latent_dim)\n            self.music_projection = nn.Linear(music_feature_dim, latent_dim)\n            self.music_encoder = nn.Sequential()\n            for _ in range(2):\n                self.music_encoder.append(\n                    WanToDanceMusicEncoderLayer(\n                        d_model=latent_dim,\n                        nhead=nhead,\n                        dim_feedforward=ff_size,\n                        dropout=dropout,\n                        activation=activation,\n                        batch_first=True,\n                        rotary=rotary,\n                        device='cuda',\n                    )\n                )\n        if wantodance_enable_unimodel:\n            self.patch_embedding_global = nn.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size)\n        if wantodance_enable_unimodel:\n            self.head_global = Head(dim, out_dim, patch_size, eps)\n        self.wantodance_enable_music_inject = wantodance_enable_music_inject\n        self.wantodance_enable_refimage = wantodance_enable_refimage\n        self.wantodance_enable_refface = wantodance_enable_refface\n        self.wantodance_enable_global = wantodance_enable_global\n        self.wantodance_enable_dynamicfps = wantodance_enable_dynamicfps\n        self.wantodance_enable_unimodel = wantodance_enable_unimodel\n\n    def wantodance_after_transformer_block(self, block_idx, hidden_states):\n        if self.wantodance_enable_music_inject:\n            if block_idx in self.music_injector.injected_block_id.keys():\n                audio_attn_id = self.music_injector.injected_block_id[block_idx]\n                audio_emb = self.merged_audio_emb  # b f n c\n                num_frames = audio_emb.shape[1]\n                input_hidden_states = hidden_states.clone()  # b (f h w) c\n                input_hidden_states = rearrange(input_hidden_states, \"b (t n) c -> (b t) n c\", t=num_frames)\n                
attn_hidden_states = self.music_injector.injector_pre_norm_feat[audio_attn_id](input_hidden_states)\n                audio_emb = rearrange(audio_emb, \"b t c -> (b t) 1 c\", t=num_frames)\n                attn_audio_emb = audio_emb\n                residual_out = self.music_injector.injector[audio_attn_id](attn_hidden_states, attn_audio_emb)\n                residual_out = rearrange(residual_out, \"(b t) n c -> b (t n) c\", t=num_frames)\n                hidden_states = hidden_states + residual_out\n        return hidden_states\n\n    def patchify(self, x: torch.Tensor, control_camera_latents_input: Optional[torch.Tensor] = None, enable_wantodance_global=False):\n        if enable_wantodance_global:\n            x = self.patch_embedding_global(x)\n        else:\n            x = self.patch_embedding(x)\n        if self.control_adapter is not None and control_camera_latents_input is not None:\n            y_camera = self.control_adapter(control_camera_latents_input)\n            x = [u + v for u, v in zip(x, y_camera)]\n            x = x[0].unsqueeze(0)\n        return x\n\n    def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor):\n        return rearrange(\n            x, 'b (f h w) (x y z c) -> b c (f x) (h y) (w z)',\n            f=grid_size[0], h=grid_size[1], w=grid_size[2], \n            x=self.patch_size[0], y=self.patch_size[1], z=self.patch_size[2]\n        )\n\n    def forward(self,\n                x: torch.Tensor,\n                timestep: torch.Tensor,\n                context: torch.Tensor,\n                clip_feature: Optional[torch.Tensor] = None,\n                y: Optional[torch.Tensor] = None,\n                use_gradient_checkpointing: bool = False,\n                use_gradient_checkpointing_offload: bool = False,\n                **kwargs,\n                ):\n        t = self.time_embedding(\n            sinusoidal_embedding_1d(self.freq_dim, timestep).to(x.dtype))\n        t_mod = self.time_projection(t).unflatten(1, (6, 
self.dim))\n        context = self.text_embedding(context)\n        \n        if self.has_image_input:\n            x = torch.cat([x, y], dim=1)  # (b, c_x + c_y, f, h, w)\n            clip_embdding = self.img_emb(clip_feature)\n            context = torch.cat([clip_embdding, context], dim=1)\n        \n        x, (f, h, w) = self.patchify(x)\n        \n        freqs = torch.cat([\n            self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),\n            self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),\n            self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)\n        ], dim=-1).reshape(f * h * w, 1, -1).to(x.device)\n\n        for block in self.blocks:\n            if self.training:\n                x = gradient_checkpoint_forward(\n                    block,\n                    use_gradient_checkpointing,\n                    use_gradient_checkpointing_offload,\n                    x, context, t_mod, freqs\n                )\n            else:\n                x = block(x, context, t_mod, freqs)\n\n        x = self.head(x, t)\n        x = self.unpatchify(x, (f, h, w))\n        return x\n"
  },
  {
    "path": "diffsynth/models/wan_video_dit_s2v.py",
    "content": "import numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom typing import Tuple\nfrom .wan_video_dit import rearrange, precompute_freqs_cis_3d, DiTBlock, Head, CrossAttention, modulate, sinusoidal_embedding_1d\nfrom ..core.gradient import gradient_checkpoint_forward\n\n\ndef torch_dfs(model: nn.Module, parent_name='root'):\n    module_names, modules = [], []\n    current_name = parent_name if parent_name else 'root'\n    module_names.append(current_name)\n    modules.append(model)\n\n    for name, child in model.named_children():\n        if parent_name:\n            child_name = f'{parent_name}.{name}'\n        else:\n            child_name = name\n        child_modules, child_names = torch_dfs(child, child_name)\n        module_names += child_names\n        modules += child_modules\n    return modules, module_names\n\n\ndef rope_precompute(x, grid_sizes, freqs, start=None):\n    b, s, n, c = x.size(0), x.size(1), x.size(2), x.size(3) // 2\n\n    # split freqs\n    if type(freqs) is list:\n        trainable_freqs = freqs[1]\n        freqs = freqs[0]\n    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)\n\n    # loop over samples\n    output = torch.view_as_complex(x.detach().reshape(b, s, n, -1, 2).to(torch.float64))\n    seq_bucket = [0]\n    if not type(grid_sizes) is list:\n        grid_sizes = [grid_sizes]\n    for g in grid_sizes:\n        if not type(g) is list:\n            g = [torch.zeros_like(g), g]\n        batch_size = g[0].shape[0]\n        for i in range(batch_size):\n            if start is None:\n                f_o, h_o, w_o = g[0][i]\n            else:\n                f_o, h_o, w_o = start[i]\n\n            f, h, w = g[1][i]\n            t_f, t_h, t_w = g[2][i]\n            seq_f, seq_h, seq_w = f - f_o, h - h_o, w - w_o\n            seq_len = int(seq_f * seq_h * seq_w)\n            if seq_len > 0:\n                if t_f > 0:\n                    factor_f, factor_h, factor_w = 
(t_f / seq_f).item(), (t_h / seq_h).item(), (t_w / seq_w).item()\n                    # Generate a list of seq_f integers starting from f_o and ending at math.ceil(factor_f * seq_f.item() + f_o.item())\n                    if f_o >= 0:\n                        f_sam = np.linspace(f_o.item(), (t_f + f_o).item() - 1, seq_f).astype(int).tolist()\n                    else:\n                        f_sam = np.linspace(-f_o.item(), (-t_f - f_o).item() + 1, seq_f).astype(int).tolist()\n                    h_sam = np.linspace(h_o.item(), (t_h + h_o).item() - 1, seq_h).astype(int).tolist()\n                    w_sam = np.linspace(w_o.item(), (t_w + w_o).item() - 1, seq_w).astype(int).tolist()\n\n                    assert f_o * f >= 0 and h_o * h >= 0 and w_o * w >= 0\n                    freqs_0 = freqs[0][f_sam] if f_o >= 0 else freqs[0][f_sam].conj()\n                    freqs_0 = freqs_0.view(seq_f, 1, 1, -1)\n\n                    freqs_i = torch.cat(\n                        [\n                            freqs_0.expand(seq_f, seq_h, seq_w, -1),\n                            freqs[1][h_sam].view(1, seq_h, 1, -1).expand(seq_f, seq_h, seq_w, -1),\n                            freqs[2][w_sam].view(1, 1, seq_w, -1).expand(seq_f, seq_h, seq_w, -1),\n                        ],\n                        dim=-1\n                    ).reshape(seq_len, 1, -1)\n                elif t_f < 0:\n                    freqs_i = trainable_freqs.unsqueeze(1)\n                # apply rotary embedding\n                output[i, seq_bucket[-1]:seq_bucket[-1] + seq_len] = freqs_i\n        seq_bucket.append(seq_bucket[-1] + seq_len)\n    return output\n\n\nclass CausalConv1d(nn.Module):\n\n    def __init__(self, chan_in, chan_out, kernel_size=3, stride=1, dilation=1, pad_mode='replicate', **kwargs):\n        super().__init__()\n\n        self.pad_mode = pad_mode\n        padding = (kernel_size - 1, 0)  # T\n        self.time_causal_padding = padding\n\n        self.conv = nn.Conv1d(chan_in, 
chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)\n\n    def forward(self, x):\n        x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)\n        return self.conv(x)\n\n\nclass MotionEncoder_tc(nn.Module):\n\n    def __init__(self, in_dim: int, hidden_dim: int, num_heads=int, need_global=True, dtype=None, device=None):\n        factory_kwargs = {\"dtype\": dtype, \"device\": device}\n        super().__init__()\n\n        self.num_heads = num_heads\n        self.need_global = need_global\n        self.conv1_local = CausalConv1d(in_dim, hidden_dim // 4 * num_heads, 3, stride=1)\n        if need_global:\n            self.conv1_global = CausalConv1d(in_dim, hidden_dim // 4, 3, stride=1)\n        self.norm1 = nn.LayerNorm(hidden_dim // 4, elementwise_affine=False, eps=1e-6, **factory_kwargs)\n        self.act = nn.SiLU()\n        self.conv2 = CausalConv1d(hidden_dim // 4, hidden_dim // 2, 3, stride=2)\n        self.conv3 = CausalConv1d(hidden_dim // 2, hidden_dim, 3, stride=2)\n\n        if need_global:\n            self.final_linear = nn.Linear(hidden_dim, hidden_dim, **factory_kwargs)\n\n        self.norm1 = nn.LayerNorm(hidden_dim // 4, elementwise_affine=False, eps=1e-6, **factory_kwargs)\n        self.norm2 = nn.LayerNorm(hidden_dim // 2, elementwise_affine=False, eps=1e-6, **factory_kwargs)\n        self.norm3 = nn.LayerNorm(hidden_dim, elementwise_affine=False, eps=1e-6, **factory_kwargs)\n        self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim))\n\n    def forward(self, x):\n        x = rearrange(x, 'b t c -> b c t')\n        x_ori = x.clone()\n        b, c, t = x.shape\n        x = self.conv1_local(x)\n        x = rearrange(x, 'b (n c) t -> (b n) t c', n=self.num_heads)\n        x = self.norm1(x)\n        x = self.act(x)\n        x = rearrange(x, 'b t c -> b c t')\n        x = self.conv2(x)\n        x = rearrange(x, 'b c t -> b t c')\n        x = self.norm2(x)\n        x = self.act(x)\n        x = rearrange(x, 'b t 
c -> b c t')\n        x = self.conv3(x)\n        x = rearrange(x, 'b c t -> b t c')\n        x = self.norm3(x)\n        x = self.act(x)\n        x = rearrange(x, '(b n) t c -> b t n c', b=b)\n        padding = self.padding_tokens.repeat(b, x.shape[1], 1, 1).to(device=x.device, dtype=x.dtype)\n        x = torch.cat([x, padding], dim=-2)\n        x_local = x.clone()\n\n        if not self.need_global:\n            return x_local\n\n        x = self.conv1_global(x_ori)\n        x = rearrange(x, 'b c t -> b t c')\n        x = self.norm1(x)\n        x = self.act(x)\n        x = rearrange(x, 'b t c -> b c t')\n        x = self.conv2(x)\n        x = rearrange(x, 'b c t -> b t c')\n        x = self.norm2(x)\n        x = self.act(x)\n        x = rearrange(x, 'b t c -> b c t')\n        x = self.conv3(x)\n        x = rearrange(x, 'b c t -> b t c')\n        x = self.norm3(x)\n        x = self.act(x)\n        x = self.final_linear(x)\n        x = rearrange(x, '(b n) t c -> b t n c', b=b)\n\n        return x, x_local\n\n\nclass FramePackMotioner(nn.Module):\n\n    def __init__(self, inner_dim=1024, num_heads=16, zip_frame_buckets=[1, 2, 16], drop_mode=\"drop\", *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))\n        self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))\n        self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))\n        self.zip_frame_buckets = torch.tensor(zip_frame_buckets, dtype=torch.long)\n\n        self.inner_dim = inner_dim\n        self.num_heads = num_heads\n        self.freqs = torch.cat(precompute_freqs_cis_3d(inner_dim // num_heads), dim=1)\n        self.drop_mode = drop_mode\n\n    def forward(self, motion_latents, add_last_motion=2):\n        motion_frames = motion_latents[0].shape[1]\n        mot = []\n        mot_remb = []\n        for m in motion_latents:\n            lat_height, lat_width 
= m.shape[2], m.shape[3]\n            padd_lat = torch.zeros(16, self.zip_frame_buckets.sum(), lat_height, lat_width).to(device=m.device, dtype=m.dtype)\n            overlap_frame = min(padd_lat.shape[1], m.shape[1])\n            if overlap_frame > 0:\n                padd_lat[:, -overlap_frame:] = m[:, -overlap_frame:]\n\n            if add_last_motion < 2 and self.drop_mode != \"drop\":\n                zero_end_frame = self.zip_frame_buckets[:self.zip_frame_buckets.__len__() - add_last_motion - 1].sum()\n                padd_lat[:, -zero_end_frame:] = 0\n\n            padd_lat = padd_lat.unsqueeze(0)\n            clean_latents_4x, clean_latents_2x, clean_latents_post = padd_lat[:, :, -self.zip_frame_buckets.sum():, :, :].split(\n                list(self.zip_frame_buckets)[::-1], dim=2\n            )  # 16, 2 ,1\n\n            # patchfy\n            clean_latents_post = self.proj(clean_latents_post).flatten(2).transpose(1, 2)\n            clean_latents_2x = self.proj_2x(clean_latents_2x).flatten(2).transpose(1, 2)\n            clean_latents_4x = self.proj_4x(clean_latents_4x).flatten(2).transpose(1, 2)\n\n            if add_last_motion < 2 and self.drop_mode == \"drop\":\n                clean_latents_post = clean_latents_post[:, :0] if add_last_motion < 2 else clean_latents_post\n                clean_latents_2x = clean_latents_2x[:, :0] if add_last_motion < 1 else clean_latents_2x\n\n            motion_lat = torch.cat([clean_latents_post, clean_latents_2x, clean_latents_4x], dim=1)\n\n            # rope\n            start_time_id = -(self.zip_frame_buckets[:1].sum())\n            end_time_id = start_time_id + self.zip_frame_buckets[0]\n            grid_sizes = [] if add_last_motion < 2 and self.drop_mode == \"drop\" else \\\n                        [\n                            [torch.tensor([start_time_id, 0, 0]).unsqueeze(0).repeat(1, 1),\n                            torch.tensor([end_time_id, lat_height // 2, lat_width // 2]).unsqueeze(0).repeat(1, 1),\n   
                         torch.tensor([self.zip_frame_buckets[0], lat_height // 2, lat_width // 2]).unsqueeze(0).repeat(1, 1), ]\n                        ]\n\n            start_time_id = -(self.zip_frame_buckets[:2].sum())\n            end_time_id = start_time_id + self.zip_frame_buckets[1] // 2\n            grid_sizes_2x = [] if add_last_motion < 1 and self.drop_mode == \"drop\" else \\\n            [\n                [torch.tensor([start_time_id, 0, 0]).unsqueeze(0).repeat(1, 1),\n                torch.tensor([end_time_id, lat_height // 4, lat_width // 4]).unsqueeze(0).repeat(1, 1),\n                torch.tensor([self.zip_frame_buckets[1], lat_height // 2, lat_width // 2]).unsqueeze(0).repeat(1, 1), ]\n            ]\n\n            start_time_id = -(self.zip_frame_buckets[:3].sum())\n            end_time_id = start_time_id + self.zip_frame_buckets[2] // 4\n            grid_sizes_4x = [\n                [\n                    torch.tensor([start_time_id, 0, 0]).unsqueeze(0).repeat(1, 1),\n                    torch.tensor([end_time_id, lat_height // 8, lat_width // 8]).unsqueeze(0).repeat(1, 1),\n                    torch.tensor([self.zip_frame_buckets[2], lat_height // 2, lat_width // 2]).unsqueeze(0).repeat(1, 1),\n                ]\n            ]\n\n            grid_sizes = grid_sizes + grid_sizes_2x + grid_sizes_4x\n\n            motion_rope_emb = rope_precompute(\n                motion_lat.detach().view(1, motion_lat.shape[1], self.num_heads, self.inner_dim // self.num_heads),\n                grid_sizes,\n                self.freqs,\n                start=None\n            )\n\n            mot.append(motion_lat)\n            mot_remb.append(motion_rope_emb)\n        return mot, mot_remb\n\n\nclass AdaLayerNorm(nn.Module):\n\n    def __init__(\n        self,\n        embedding_dim: int,\n        output_dim: int,\n        norm_eps: float = 1e-5,\n    ):\n        super().__init__()\n        self.silu = nn.SiLU()\n        self.linear = nn.Linear(embedding_dim, 
output_dim)\n        self.norm = nn.LayerNorm(output_dim // 2, norm_eps, elementwise_affine=False)\n\n    def forward(self, x, temb):\n        temb = self.linear(F.silu(temb))\n        shift, scale = temb.chunk(2, dim=1)\n        shift = shift[:, None, :]\n        scale = scale[:, None, :]\n        x = self.norm(x) * (1 + scale) + shift\n        return x\n\n\nclass AudioInjector_WAN(nn.Module):\n\n    def __init__(\n        self,\n        all_modules,\n        all_modules_names,\n        dim=2048,\n        num_heads=32,\n        inject_layer=[0, 27],\n        enable_adain=False,\n        adain_dim=2048,\n    ):\n        super().__init__()\n        self.injected_block_id = {}\n        audio_injector_id = 0\n        for mod_name, mod in zip(all_modules_names, all_modules):\n            if isinstance(mod, DiTBlock):\n                for inject_id in inject_layer:\n                    if f'transformer_blocks.{inject_id}' in mod_name:\n                        self.injected_block_id[inject_id] = audio_injector_id\n                        audio_injector_id += 1\n\n        self.injector = nn.ModuleList([CrossAttention(\n            dim=dim,\n            num_heads=num_heads,\n        ) for _ in range(audio_injector_id)])\n        self.injector_pre_norm_feat = nn.ModuleList([nn.LayerNorm(\n            dim,\n            elementwise_affine=False,\n            eps=1e-6,\n        ) for _ in range(audio_injector_id)])\n        self.injector_pre_norm_vec = nn.ModuleList([nn.LayerNorm(\n            dim,\n            elementwise_affine=False,\n            eps=1e-6,\n        ) for _ in range(audio_injector_id)])\n        if enable_adain:\n            self.injector_adain_layers = nn.ModuleList([AdaLayerNorm(output_dim=dim * 2, embedding_dim=adain_dim) for _ in range(audio_injector_id)])\n\n\nclass CausalAudioEncoder(nn.Module):\n\n    def __init__(self, dim=5120, num_layers=25, out_dim=2048, num_token=4, need_global=False):\n        super().__init__()\n        self.encoder = 
MotionEncoder_tc(in_dim=dim, hidden_dim=out_dim, num_heads=num_token, need_global=need_global)\n        weight = torch.ones((1, num_layers, 1, 1)) * 0.01\n\n        self.weights = torch.nn.Parameter(weight)\n        self.act = torch.nn.SiLU()\n\n    def forward(self, features):\n        # features B * num_layers * dim * video_length\n        weights = self.act(self.weights.to(device=features.device, dtype=features.dtype))\n        weights_sum = weights.sum(dim=1, keepdims=True)\n        weighted_feat = ((features * weights) / weights_sum).sum(dim=1)  # b dim f\n        weighted_feat = weighted_feat.permute(0, 2, 1)  # b f dim\n        res = self.encoder(weighted_feat)  # b f n dim\n        return res  # b f n dim\n\n\nclass WanS2VDiTBlock(DiTBlock):\n\n    def forward(self, x, context, t_mod, seq_len_x, freqs):\n        t_mod = (self.modulation.unsqueeze(2).to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=1)\n        # t_mod[:, :, 0] for x, t_mod[:, :, 1] for other like ref, motion, etc.\n        t_mod = [\n            torch.cat([element[:, :, 0].expand(1, seq_len_x, x.shape[-1]), element[:, :, 1].expand(1, x.shape[1] - seq_len_x, x.shape[-1])], dim=1)\n            for element in t_mod\n        ]\n        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = t_mod\n        input_x = modulate(self.norm1(x), shift_msa, scale_msa)\n        x = self.gate(x, gate_msa, self.self_attn(input_x, freqs))\n        x = x + self.cross_attn(self.norm3(x), context)\n        input_x = modulate(self.norm2(x), shift_mlp, scale_mlp)\n        x = self.gate(x, gate_mlp, self.ffn(input_x))\n        return x\n\n\nclass WanS2VModel(torch.nn.Module):\n\n    def __init__(\n        self,\n        dim: int,\n        in_dim: int,\n        ffn_dim: int,\n        out_dim: int,\n        text_dim: int,\n        freq_dim: int,\n        eps: float,\n        patch_size: Tuple[int, int, int],\n        num_heads: int,\n        num_layers: int,\n        cond_dim: int,\n        
audio_dim: int,\n        num_audio_token: int,\n        enable_adain: bool = True,\n        audio_inject_layers: list = [0, 4, 8, 12, 16, 20, 24, 27, 30, 33, 36, 39],\n        zero_timestep: bool = True,\n        add_last_motion: bool = True,\n        framepack_drop_mode: str = \"padd\",\n        fuse_vae_embedding_in_latents: bool = True,\n        require_vae_embedding: bool = False,\n        seperated_timestep: bool = False,\n        require_clip_embedding: bool = False,\n    ):\n        super().__init__()\n        self.dim = dim\n        self.in_dim = in_dim\n        self.freq_dim = freq_dim\n        self.patch_size = patch_size\n        self.num_heads = num_heads\n        self.enbale_adain = enable_adain\n        self.add_last_motion = add_last_motion\n        self.zero_timestep = zero_timestep\n        self.fuse_vae_embedding_in_latents = fuse_vae_embedding_in_latents\n        self.require_vae_embedding = require_vae_embedding\n        self.seperated_timestep = seperated_timestep\n        self.require_clip_embedding = require_clip_embedding\n\n        self.patch_embedding = nn.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size)\n        self.text_embedding = nn.Sequential(nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'), nn.Linear(dim, dim))\n        self.time_embedding = nn.Sequential(nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))\n        self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))\n\n        self.blocks = nn.ModuleList([WanS2VDiTBlock(False, dim, num_heads, ffn_dim, eps) for _ in range(num_layers)])\n        self.head = Head(dim, out_dim, patch_size, eps)\n        self.freqs = torch.cat(precompute_freqs_cis_3d(dim // num_heads), dim=1)\n\n        self.cond_encoder = nn.Conv3d(cond_dim, dim, kernel_size=patch_size, stride=patch_size)\n        self.casual_audio_encoder = CausalAudioEncoder(dim=audio_dim, out_dim=dim, num_token=num_audio_token, need_global=enable_adain)\n        all_modules, 
all_modules_names = torch_dfs(self.blocks, parent_name=\"root.transformer_blocks\")\n        self.audio_injector = AudioInjector_WAN(\n            all_modules,\n            all_modules_names,\n            dim=dim,\n            num_heads=num_heads,\n            inject_layer=audio_inject_layers,\n            enable_adain=enable_adain,\n            adain_dim=dim,\n        )\n        self.trainable_cond_mask = nn.Embedding(3, dim)\n        self.frame_packer = FramePackMotioner(inner_dim=dim, num_heads=num_heads, zip_frame_buckets=[1, 2, 16], drop_mode=framepack_drop_mode)\n\n    def patchify(self, x: torch.Tensor):\n        grid_size = x.shape[2:]\n        x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous()\n        return x, grid_size  # x, grid_size: (f, h, w)\n\n    def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor):\n        return rearrange(\n            x,\n            'b (f h w) (x y z c) -> b c (f x) (h y) (w z)',\n            f=grid_size[0],\n            h=grid_size[1],\n            w=grid_size[2],\n            x=self.patch_size[0],\n            y=self.patch_size[1],\n            z=self.patch_size[2]\n        )\n\n    def process_motion_frame_pack(self, motion_latents, drop_motion_frames=False, add_last_motion=2):\n        flattern_mot, mot_remb = self.frame_packer(motion_latents, add_last_motion)\n        if drop_motion_frames:\n            return [m[:, :0] for m in flattern_mot], [m[:, :0] for m in mot_remb]\n        else:\n            return flattern_mot, mot_remb\n\n    def inject_motion(self, x, rope_embs, mask_input, motion_latents, drop_motion_frames=True, add_last_motion=2):\n        # inject the motion frames token to the hidden states\n        mot, mot_remb = self.process_motion_frame_pack(motion_latents, drop_motion_frames=drop_motion_frames, add_last_motion=add_last_motion)\n        if len(mot) > 0:\n            x = torch.cat([x, mot[0]], dim=1)\n            rope_embs = torch.cat([rope_embs, mot_remb[0]], dim=1)\n            
mask_input = torch.cat(\n                [mask_input, 2 * torch.ones([1, x.shape[1] - mask_input.shape[1]], device=mask_input.device, dtype=mask_input.dtype)], dim=1\n            )\n        return x, rope_embs, mask_input\n\n    def after_transformer_block(self, block_idx, hidden_states, audio_emb_global, audio_emb, original_seq_len, use_unified_sequence_parallel=False):\n        if block_idx in self.audio_injector.injected_block_id.keys():\n            audio_attn_id = self.audio_injector.injected_block_id[block_idx]\n            num_frames = audio_emb.shape[1]\n            if use_unified_sequence_parallel:\n                from xfuser.core.distributed import get_sp_group\n                hidden_states = get_sp_group().all_gather(hidden_states, dim=1)\n\n            input_hidden_states = hidden_states[:, :original_seq_len].clone()  # b (f h w) c\n            input_hidden_states = rearrange(input_hidden_states, \"b (t n) c -> (b t) n c\", t=num_frames)\n\n            audio_emb_global = rearrange(audio_emb_global, \"b t n c -> (b t) n c\")\n            adain_hidden_states = self.audio_injector.injector_adain_layers[audio_attn_id](input_hidden_states, temb=audio_emb_global[:, 0])\n            attn_hidden_states = adain_hidden_states\n\n            audio_emb = rearrange(audio_emb, \"b t n c -> (b t) n c\", t=num_frames)\n            attn_audio_emb = audio_emb\n            residual_out = self.audio_injector.injector[audio_attn_id](attn_hidden_states, attn_audio_emb)\n            residual_out = rearrange(residual_out, \"(b t) n c -> b (t n) c\", t=num_frames)\n            hidden_states[:, :original_seq_len] = hidden_states[:, :original_seq_len] + residual_out\n            if use_unified_sequence_parallel:\n                from xfuser.core.distributed import get_sequence_parallel_world_size, get_sequence_parallel_rank\n                hidden_states = torch.chunk(hidden_states, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]\n        return 
hidden_states\n\n    def cal_audio_emb(self, audio_input, motion_frames=[73, 19]):\n        audio_input = torch.cat([audio_input[..., 0:1].repeat(1, 1, 1, motion_frames[0]), audio_input], dim=-1)\n        audio_emb_global, audio_emb = self.casual_audio_encoder(audio_input)\n        audio_emb_global = audio_emb_global[:, motion_frames[1]:].clone()\n        merged_audio_emb = audio_emb[:, motion_frames[1]:, :]\n        return audio_emb_global, merged_audio_emb\n\n    def get_grid_sizes(self, grid_size_x, grid_size_ref):\n        f, h, w = grid_size_x\n        rf, rh, rw = grid_size_ref\n        grid_sizes_x = torch.tensor([f, h, w], dtype=torch.long).unsqueeze(0)\n        grid_sizes_x = [[torch.zeros_like(grid_sizes_x), grid_sizes_x, grid_sizes_x]]\n        grid_sizes_ref = [[\n            torch.tensor([30, 0, 0]).unsqueeze(0),\n            torch.tensor([31, rh, rw]).unsqueeze(0),\n            torch.tensor([1, rh, rw]).unsqueeze(0),\n        ]]\n        return grid_sizes_x + grid_sizes_ref\n\n    def forward(\n        self,\n        latents,\n        timestep,\n        context,\n        audio_input,\n        motion_latents,\n        pose_cond,\n        use_gradient_checkpointing_offload=False,\n        use_gradient_checkpointing=False\n    ):\n        origin_ref_latents = latents[:, :, 0:1]\n        x = latents[:, :, 1:]\n\n        # context embedding\n        context = self.text_embedding(context)\n\n        # audio encode\n        audio_emb_global, merged_audio_emb = self.cal_audio_emb(audio_input)\n\n        # x and pose_cond\n        pose_cond = torch.zeros_like(x) if pose_cond is None else pose_cond\n        x, (f, h, w) = self.patchify(self.patch_embedding(x) + self.cond_encoder(pose_cond))  # torch.Size([1, 29120, 5120])\n        seq_len_x = x.shape[1]\n\n        # reference image\n        ref_latents, (rf, rh, rw) = self.patchify(self.patch_embedding(origin_ref_latents))  # torch.Size([1, 1456, 5120])\n        grid_sizes = self.get_grid_sizes((f, h, w), (rf, 
rh, rw))\n        x = torch.cat([x, ref_latents], dim=1)\n        # mask\n        mask = torch.cat([torch.zeros([1, seq_len_x]), torch.ones([1, ref_latents.shape[1]])], dim=1).to(torch.long).to(x.device)\n        # freqs\n        pre_compute_freqs = rope_precompute(\n            x.detach().view(1, x.size(1), self.num_heads, self.dim // self.num_heads), grid_sizes, self.freqs, start=None\n        )\n        # motion\n        x, pre_compute_freqs, mask = self.inject_motion(x, pre_compute_freqs, mask, motion_latents, add_last_motion=2)\n\n        x = x + self.trainable_cond_mask(mask).to(x.dtype)\n\n        # t_mod\n        timestep = torch.cat([timestep, torch.zeros([1], dtype=timestep.dtype, device=timestep.device)])\n        t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep))\n        t_mod = self.time_projection(t).unflatten(1, (6, self.dim)).unsqueeze(2).transpose(0, 2)\n\n        for block_id, block in enumerate(self.blocks):\n            x = gradient_checkpoint_forward(\n                block,\n                use_gradient_checkpointing,\n                use_gradient_checkpointing_offload,\n                x, context, t_mod, seq_len_x, pre_compute_freqs[0]\n            )\n            x = gradient_checkpoint_forward(\n                lambda x: self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x),\n                use_gradient_checkpointing,\n                use_gradient_checkpointing_offload,\n                x\n            )\n\n        x = x[:, :seq_len_x]\n        x = self.head(x, t[:-1])\n        x = self.unpatchify(x, (f, h, w))\n        # make compatible with wan video\n        x = torch.cat([origin_ref_latents, x], dim=2)\n        return x\n"
  },
  {
    "path": "diffsynth/models/wan_video_image_encoder.py",
    "content": "\"\"\"\nConcise re-implementation of\n``https://github.com/openai/CLIP'' and\n``https://github.com/mlfoundations/open_clip''.\n\"\"\"\nimport math\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torchvision.transforms as T\nfrom .wan_video_dit import flash_attention\n\n\nclass SelfAttention(nn.Module):\n\n    def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5):\n        assert dim % num_heads == 0\n        super().__init__()\n        self.dim = dim\n        self.num_heads = num_heads\n        self.head_dim = dim // num_heads\n        self.eps = eps\n\n        # layers\n        self.q = nn.Linear(dim, dim)\n        self.k = nn.Linear(dim, dim)\n        self.v = nn.Linear(dim, dim)\n        self.o = nn.Linear(dim, dim)\n        self.dropout = nn.Dropout(dropout)\n\n    def forward(self, x, mask):\n        \"\"\"\n        x:   [B, L, C].\n        \"\"\"\n        b, s, c, n, d = *x.size(), self.num_heads, self.head_dim\n\n        # compute query, key, value\n        q = self.q(x).reshape(b, s, n, d).permute(0, 2, 1, 3)\n        k = self.k(x).reshape(b, s, n, d).permute(0, 2, 1, 3)\n        v = self.v(x).reshape(b, s, n, d).permute(0, 2, 1, 3)\n\n        # compute attention\n        p = self.dropout.p if self.training else 0.0\n        x = F.scaled_dot_product_attention(q, k, v, mask, p)\n        x = x.permute(0, 2, 1, 3).reshape(b, s, c)\n\n        # output\n        x = self.o(x)\n        x = self.dropout(x)\n        return x\n\n\nclass AttentionBlock(nn.Module):\n\n    def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5):\n        super().__init__()\n        self.dim = dim\n        self.num_heads = num_heads\n        self.post_norm = post_norm\n        self.eps = eps\n\n        # layers\n        self.attn = SelfAttention(dim, num_heads, dropout, eps)\n        self.norm1 = nn.LayerNorm(dim, eps=eps)\n        self.ffn = nn.Sequential(\n            nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, 
dim),\n            nn.Dropout(dropout))\n        self.norm2 = nn.LayerNorm(dim, eps=eps)\n\n    def forward(self, x, mask):\n        if self.post_norm:\n            x = self.norm1(x + self.attn(x, mask))\n            x = self.norm2(x + self.ffn(x))\n        else:\n            x = x + self.attn(self.norm1(x), mask)\n            x = x + self.ffn(self.norm2(x))\n        return x\n\n\nclass XLMRoberta(nn.Module):\n    \"\"\"\n    XLMRobertaModel with no pooler and no LM head.\n    \"\"\"\n\n    def __init__(self,\n                 vocab_size=250002,\n                 max_seq_len=514,\n                 type_size=1,\n                 pad_id=1,\n                 dim=1024,\n                 num_heads=16,\n                 num_layers=24,\n                 post_norm=True,\n                 dropout=0.1,\n                 eps=1e-5):\n        super().__init__()\n        self.vocab_size = vocab_size\n        self.max_seq_len = max_seq_len\n        self.type_size = type_size\n        self.pad_id = pad_id\n        self.dim = dim\n        self.num_heads = num_heads\n        self.num_layers = num_layers\n        self.post_norm = post_norm\n        self.eps = eps\n\n        # embeddings\n        self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id)\n        self.type_embedding = nn.Embedding(type_size, dim)\n        self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id)\n        self.dropout = nn.Dropout(dropout)\n\n        # blocks\n        self.blocks = nn.ModuleList([\n            AttentionBlock(dim, num_heads, post_norm, dropout, eps)\n            for _ in range(num_layers)\n        ])\n\n        # norm layer\n        self.norm = nn.LayerNorm(dim, eps=eps)\n\n    def forward(self, ids):\n        \"\"\"\n        ids: [B, L] of torch.LongTensor.\n        \"\"\"\n        b, s = ids.shape\n        mask = ids.ne(self.pad_id).long()\n\n        # embeddings\n        x = self.token_embedding(ids) + \\\n            
self.type_embedding(torch.zeros_like(ids)) + \\\n            self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask)\n        if self.post_norm:\n            x = self.norm(x)\n        x = self.dropout(x)\n\n        # blocks\n        mask = torch.where(\n            mask.view(b, 1, 1, s).gt(0), 0.0,\n            torch.finfo(x.dtype).min)\n        for block in self.blocks:\n            x = block(x, mask)\n\n        # output\n        if not self.post_norm:\n            x = self.norm(x)\n        return x\n\n\ndef xlm_roberta_large(pretrained=False,\n                      return_tokenizer=False,\n                      device='cpu',\n                      **kwargs):\n    \"\"\"\n    XLMRobertaLarge adapted from Huggingface.\n    \"\"\"\n    # params\n    cfg = dict(\n        vocab_size=250002,\n        max_seq_len=514,\n        type_size=1,\n        pad_id=1,\n        dim=1024,\n        num_heads=16,\n        num_layers=24,\n        post_norm=True,\n        dropout=0.1,\n        eps=1e-5)\n    cfg.update(**kwargs)\n\n    # init model\n    if pretrained:\n        from sora import DOWNLOAD_TO_CACHE\n\n        # init a meta model\n        with torch.device('meta'):\n            model = XLMRoberta(**cfg)\n\n        # load checkpoint\n        model.load_state_dict(\n            torch.load(\n                DOWNLOAD_TO_CACHE('models/xlm_roberta/xlm_roberta_large.pth'),\n                map_location=device),\n            assign=True)\n    else:\n        # init a model on device\n        with torch.device(device):\n            model = XLMRoberta(**cfg)\n\n    # init tokenizer\n    if return_tokenizer:\n        from sora.data import HuggingfaceTokenizer\n        tokenizer = HuggingfaceTokenizer(\n            name='xlm-roberta-large',\n            seq_len=model.text_len,\n            clean='whitespace')\n        return model, tokenizer\n    else:\n        return model\n\n\n\ndef pos_interpolate(pos, seq_len):\n    if pos.size(1) == seq_len:\n        return pos\n    
else:\n        src_grid = int(math.sqrt(pos.size(1)))\n        tar_grid = int(math.sqrt(seq_len))\n        n = pos.size(1) - src_grid * src_grid\n        return torch.cat([\n            pos[:, :n],\n            F.interpolate(\n                pos[:, n:].float().reshape(1, src_grid, src_grid, -1).permute(\n                    0, 3, 1, 2),\n                size=(tar_grid, tar_grid),\n                mode='bicubic',\n                align_corners=False).flatten(2).transpose(1, 2)\n        ],\n                         dim=1)\n\n\nclass QuickGELU(nn.Module):\n\n    def forward(self, x):\n        return x * torch.sigmoid(1.702 * x)\n\n\nclass LayerNorm(nn.LayerNorm):\n\n    def forward(self, x):\n        return super().forward(x).type_as(x)\n\n\nclass SelfAttention(nn.Module):\n\n    def __init__(self,\n                 dim,\n                 num_heads,\n                 causal=False,\n                 attn_dropout=0.0,\n                 proj_dropout=0.0):\n        assert dim % num_heads == 0\n        super().__init__()\n        self.dim = dim\n        self.num_heads = num_heads\n        self.head_dim = dim // num_heads\n        self.causal = causal\n        self.attn_dropout = attn_dropout\n        self.proj_dropout = proj_dropout\n\n        # layers\n        self.to_qkv = nn.Linear(dim, dim * 3)\n        self.proj = nn.Linear(dim, dim)\n\n    def forward(self, x):\n        \"\"\"\n        x:   [B, L, C].\n        \"\"\"\n        # compute query, key, value\n        q, k, v = self.to_qkv(x).chunk(3, dim=-1)\n\n        # compute attention\n        x = flash_attention(q, k, v, num_heads=self.num_heads, compatibility_mode=True)\n\n        # output\n        x = self.proj(x)\n        x = F.dropout(x, self.proj_dropout, self.training)\n        return x\n\n\nclass SwiGLU(nn.Module):\n\n    def __init__(self, dim, mid_dim):\n        super().__init__()\n        self.dim = dim\n        self.mid_dim = mid_dim\n\n        # layers\n        self.fc1 = nn.Linear(dim, mid_dim)\n        
self.fc2 = nn.Linear(dim, mid_dim)\n        self.fc3 = nn.Linear(mid_dim, dim)\n\n    def forward(self, x):\n        x = F.silu(self.fc1(x)) * self.fc2(x)\n        x = self.fc3(x)\n        return x\n\n\nclass AttentionBlock(nn.Module):\n\n    def __init__(self,\n                 dim,\n                 mlp_ratio,\n                 num_heads,\n                 post_norm=False,\n                 causal=False,\n                 activation='quick_gelu',\n                 attn_dropout=0.0,\n                 proj_dropout=0.0,\n                 norm_eps=1e-5):\n        assert activation in ['quick_gelu', 'gelu', 'swi_glu']\n        super().__init__()\n        self.dim = dim\n        self.mlp_ratio = mlp_ratio\n        self.num_heads = num_heads\n        self.post_norm = post_norm\n        self.causal = causal\n        self.norm_eps = norm_eps\n\n        # layers\n        self.norm1 = LayerNorm(dim, eps=norm_eps)\n        self.attn = SelfAttention(dim, num_heads, causal, attn_dropout,\n                                  proj_dropout)\n        self.norm2 = LayerNorm(dim, eps=norm_eps)\n        if activation == 'swi_glu':\n            self.mlp = SwiGLU(dim, int(dim * mlp_ratio))\n        else:\n            self.mlp = nn.Sequential(\n                nn.Linear(dim, int(dim * mlp_ratio)),\n                QuickGELU() if activation == 'quick_gelu' else nn.GELU(),\n                nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))\n\n    def forward(self, x):\n        if self.post_norm:\n            x = x + self.norm1(self.attn(x))\n            x = x + self.norm2(self.mlp(x))\n        else:\n            x = x + self.attn(self.norm1(x))\n            x = x + self.mlp(self.norm2(x))\n        return x\n\n\nclass AttentionPool(nn.Module):\n\n    def __init__(self,\n                 dim,\n                 mlp_ratio,\n                 num_heads,\n                 activation='gelu',\n                 proj_dropout=0.0,\n                 norm_eps=1e-5):\n        assert dim % 
num_heads == 0\n        super().__init__()\n        self.dim = dim\n        self.mlp_ratio = mlp_ratio\n        self.num_heads = num_heads\n        self.head_dim = dim // num_heads\n        self.proj_dropout = proj_dropout\n        self.norm_eps = norm_eps\n\n        # layers\n        gain = 1.0 / math.sqrt(dim)\n        self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))\n        self.to_q = nn.Linear(dim, dim)\n        self.to_kv = nn.Linear(dim, dim * 2)\n        self.proj = nn.Linear(dim, dim)\n        self.norm = LayerNorm(dim, eps=norm_eps)\n        self.mlp = nn.Sequential(\n            nn.Linear(dim, int(dim * mlp_ratio)),\n            QuickGELU() if activation == 'quick_gelu' else nn.GELU(),\n            nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))\n\n    def forward(self, x):\n        \"\"\"\n        x:  [B, L, C].\n        \"\"\"\n        b, s, c, n, d = *x.size(), self.num_heads, self.head_dim\n\n        # compute query, key, value\n        q = self.to_q(self.cls_embedding).view(1, 1, n*d).expand(b, -1, -1)\n        k, v = self.to_kv(x).chunk(2, dim=-1)\n\n        # compute attention\n        x = flash_attention(q, k, v, num_heads=self.num_heads, compatibility_mode=True)\n        x = x.reshape(b, 1, c)\n\n        # output\n        x = self.proj(x)\n        x = F.dropout(x, self.proj_dropout, self.training)\n\n        # mlp\n        x = x + self.mlp(self.norm(x))\n        return x[:, 0]\n\n\nclass VisionTransformer(nn.Module):\n\n    def __init__(self,\n                 image_size=224,\n                 patch_size=16,\n                 dim=768,\n                 mlp_ratio=4,\n                 out_dim=512,\n                 num_heads=12,\n                 num_layers=12,\n                 pool_type='token',\n                 pre_norm=True,\n                 post_norm=False,\n                 activation='quick_gelu',\n                 attn_dropout=0.0,\n                 proj_dropout=0.0,\n                 
embedding_dropout=0.0,\n                 norm_eps=1e-5):\n        if image_size % patch_size != 0:\n            print(\n                '[WARNING] image_size is not divisible by patch_size',\n                flush=True)\n        assert pool_type in ('token', 'token_fc', 'attn_pool')\n        out_dim = out_dim or dim\n        super().__init__()\n        self.image_size = image_size\n        self.patch_size = patch_size\n        self.num_patches = (image_size // patch_size)**2\n        self.dim = dim\n        self.mlp_ratio = mlp_ratio\n        self.out_dim = out_dim\n        self.num_heads = num_heads\n        self.num_layers = num_layers\n        self.pool_type = pool_type\n        self.post_norm = post_norm\n        self.norm_eps = norm_eps\n\n        # embeddings\n        gain = 1.0 / math.sqrt(dim)\n        self.patch_embedding = nn.Conv2d(\n            3,\n            dim,\n            kernel_size=patch_size,\n            stride=patch_size,\n            bias=not pre_norm)\n        if pool_type in ('token', 'token_fc'):\n            self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))\n        self.pos_embedding = nn.Parameter(gain * torch.randn(\n            1, self.num_patches +\n            (1 if pool_type in ('token', 'token_fc') else 0), dim))\n        self.dropout = nn.Dropout(embedding_dropout)\n\n        # transformer\n        self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None\n        self.transformer = nn.Sequential(*[\n            AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False,\n                           activation, attn_dropout, proj_dropout, norm_eps)\n            for _ in range(num_layers)\n        ])\n        self.post_norm = LayerNorm(dim, eps=norm_eps)\n\n        # head\n        if pool_type == 'token':\n            self.head = nn.Parameter(gain * torch.randn(dim, out_dim))\n        elif pool_type == 'token_fc':\n            self.head = nn.Linear(dim, out_dim)\n        elif pool_type == 'attn_pool':\n     
       self.head = AttentionPool(dim, mlp_ratio, num_heads, activation,\n                                      proj_dropout, norm_eps)\n\n    def forward(self, x, interpolation=False, use_31_block=False):\n        b = x.size(0)\n\n        # embeddings\n        x = self.patch_embedding(x).flatten(2).permute(0, 2, 1)\n        if self.pool_type in ('token', 'token_fc'):\n            x = torch.cat([self.cls_embedding.expand(b, -1, -1).to(dtype=x.dtype, device=x.device), x], dim=1)\n        if interpolation:\n            e = pos_interpolate(self.pos_embedding, x.size(1))\n        else:\n            e = self.pos_embedding\n        e = e.to(dtype=x.dtype, device=x.device)\n        x = self.dropout(x + e)\n        if self.pre_norm is not None:\n            x = self.pre_norm(x)\n\n        # transformer\n        if use_31_block:\n            x = self.transformer[:-1](x)\n            return x\n        else:\n            x = self.transformer(x)\n            return x\n\n\nclass CLIP(nn.Module):\n\n    def __init__(self,\n                 embed_dim=512,\n                 image_size=224,\n                 patch_size=16,\n                 vision_dim=768,\n                 vision_mlp_ratio=4,\n                 vision_heads=12,\n                 vision_layers=12,\n                 vision_pool='token',\n                 vision_pre_norm=True,\n                 vision_post_norm=False,\n                 vocab_size=49408,\n                 text_len=77,\n                 text_dim=512,\n                 text_mlp_ratio=4,\n                 text_heads=8,\n                 text_layers=12,\n                 text_causal=True,\n                 text_pool='argmax',\n                 text_head_bias=False,\n                 logit_bias=None,\n                 activation='quick_gelu',\n                 attn_dropout=0.0,\n                 proj_dropout=0.0,\n                 embedding_dropout=0.0,\n                 norm_eps=1e-5):\n        super().__init__()\n        self.embed_dim = embed_dim\n        
self.image_size = image_size\n        self.patch_size = patch_size\n        self.vision_dim = vision_dim\n        self.vision_mlp_ratio = vision_mlp_ratio\n        self.vision_heads = vision_heads\n        self.vision_layers = vision_layers\n        self.vision_pool = vision_pool\n        self.vision_pre_norm = vision_pre_norm\n        self.vision_post_norm = vision_post_norm\n        self.vocab_size = vocab_size\n        self.text_len = text_len\n        self.text_dim = text_dim\n        self.text_mlp_ratio = text_mlp_ratio\n        self.text_heads = text_heads\n        self.text_layers = text_layers\n        self.text_causal = text_causal\n        self.text_pool = text_pool\n        self.text_head_bias = text_head_bias\n        self.norm_eps = norm_eps\n\n        # models\n        self.visual = VisionTransformer(\n            image_size=image_size,\n            patch_size=patch_size,\n            dim=vision_dim,\n            mlp_ratio=vision_mlp_ratio,\n            out_dim=embed_dim,\n            num_heads=vision_heads,\n            num_layers=vision_layers,\n            pool_type=vision_pool,\n            pre_norm=vision_pre_norm,\n            post_norm=vision_post_norm,\n            activation=activation,\n            attn_dropout=attn_dropout,\n            proj_dropout=proj_dropout,\n            embedding_dropout=embedding_dropout,\n            norm_eps=norm_eps)\n        self.textual = TextTransformer(\n            vocab_size=vocab_size,\n            text_len=text_len,\n            dim=text_dim,\n            mlp_ratio=text_mlp_ratio,\n            out_dim=embed_dim,\n            num_heads=text_heads,\n            num_layers=text_layers,\n            causal=text_causal,\n            pool_type=text_pool,\n            head_bias=text_head_bias,\n            activation=activation,\n            attn_dropout=attn_dropout,\n            proj_dropout=proj_dropout,\n            embedding_dropout=embedding_dropout,\n            norm_eps=norm_eps)\n        self.log_scale = 
nn.Parameter(math.log(1 / 0.07) * torch.ones([]))\n        if logit_bias is not None:\n            self.logit_bias = nn.Parameter(logit_bias * torch.ones([]))\n\n        # initialize weights\n        self.init_weights()\n\n    def forward(self, imgs, txt_ids):\n        \"\"\"\n        imgs:       [B, 3, H, W] of torch.float32.\n        - mean:     [0.48145466, 0.4578275, 0.40821073]\n        - std:      [0.26862954, 0.26130258, 0.27577711]\n        txt_ids:    [B, L] of torch.long. Encoded by data.CLIPTokenizer.\n        \"\"\"\n        xi = self.visual(imgs)\n        xt = self.textual(txt_ids)\n        return xi, xt\n\n    def init_weights(self):\n        # embeddings\n        nn.init.normal_(self.textual.token_embedding.weight, std=0.02)\n        nn.init.normal_(self.visual.patch_embedding.weight, std=0.1)\n\n        # attentions\n        for modality in ['visual', 'textual']:\n            dim = self.vision_dim if modality == 'visual' else self.text_dim\n            transformer = getattr(self, modality).transformer\n            proj_gain = (1.0 / math.sqrt(dim)) * (\n                1.0 / math.sqrt(2 * len(transformer)))\n            attn_gain = 1.0 / math.sqrt(dim)\n            mlp_gain = 1.0 / math.sqrt(2.0 * dim)\n            for block in transformer:\n                nn.init.normal_(block.attn.to_qkv.weight, std=attn_gain)\n                nn.init.normal_(block.attn.proj.weight, std=proj_gain)\n                nn.init.normal_(block.mlp[0].weight, std=mlp_gain)\n                nn.init.normal_(block.mlp[2].weight, std=proj_gain)\n\n    def param_groups(self):\n        groups = [{\n            'params': [\n                p for n, p in self.named_parameters()\n                if 'norm' in n or n.endswith('bias')\n            ],\n            'weight_decay': 0.0\n        }, {\n            'params': [\n                p for n, p in self.named_parameters()\n                if not ('norm' in n or n.endswith('bias'))\n            ]\n        }]\n        return 
groups\n\n\nclass XLMRobertaWithHead(XLMRoberta):\n\n    def __init__(self, **kwargs):\n        self.out_dim = kwargs.pop('out_dim')\n        super().__init__(**kwargs)\n\n        # head\n        mid_dim = (self.dim + self.out_dim) // 2\n        self.head = nn.Sequential(\n            nn.Linear(self.dim, mid_dim, bias=False), nn.GELU(),\n            nn.Linear(mid_dim, self.out_dim, bias=False))\n\n    def forward(self, ids):\n        # xlm-roberta\n        x = super().forward(ids)\n\n        # average pooling\n        mask = ids.ne(self.pad_id).unsqueeze(-1).to(x)\n        x = (x * mask).sum(dim=1) / mask.sum(dim=1)\n\n        # head\n        x = self.head(x)\n        return x\n\n\nclass XLMRobertaCLIP(nn.Module):\n\n    def __init__(self,\n                 embed_dim=1024,\n                 image_size=224,\n                 patch_size=14,\n                 vision_dim=1280,\n                 vision_mlp_ratio=4,\n                 vision_heads=16,\n                 vision_layers=32,\n                 vision_pool='token',\n                 vision_pre_norm=True,\n                 vision_post_norm=False,\n                 activation='gelu',\n                 vocab_size=250002,\n                 max_text_len=514,\n                 type_size=1,\n                 pad_id=1,\n                 text_dim=1024,\n                 text_heads=16,\n                 text_layers=24,\n                 text_post_norm=True,\n                 text_dropout=0.1,\n                 attn_dropout=0.0,\n                 proj_dropout=0.0,\n                 embedding_dropout=0.0,\n                 norm_eps=1e-5):\n        super().__init__()\n        self.embed_dim = embed_dim\n        self.image_size = image_size\n        self.patch_size = patch_size\n        self.vision_dim = vision_dim\n        self.vision_mlp_ratio = vision_mlp_ratio\n        self.vision_heads = vision_heads\n        self.vision_layers = vision_layers\n        self.vision_pre_norm = vision_pre_norm\n        self.vision_post_norm 
= vision_post_norm\n        self.activation = activation\n        self.vocab_size = vocab_size\n        self.max_text_len = max_text_len\n        self.type_size = type_size\n        self.pad_id = pad_id\n        self.text_dim = text_dim\n        self.text_heads = text_heads\n        self.text_layers = text_layers\n        self.text_post_norm = text_post_norm\n        self.norm_eps = norm_eps\n\n        # models\n        self.visual = VisionTransformer(\n            image_size=image_size,\n            patch_size=patch_size,\n            dim=vision_dim,\n            mlp_ratio=vision_mlp_ratio,\n            out_dim=embed_dim,\n            num_heads=vision_heads,\n            num_layers=vision_layers,\n            pool_type=vision_pool,\n            pre_norm=vision_pre_norm,\n            post_norm=vision_post_norm,\n            activation=activation,\n            attn_dropout=attn_dropout,\n            proj_dropout=proj_dropout,\n            embedding_dropout=embedding_dropout,\n            norm_eps=norm_eps)\n        self.textual = None\n        self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))\n\n    def forward(self, imgs, txt_ids):\n        \"\"\"\n        imgs:       [B, 3, H, W] of torch.float32.\n        - mean:     [0.48145466, 0.4578275, 0.40821073]\n        - std:      [0.26862954, 0.26130258, 0.27577711]\n        txt_ids:    [B, L] of torch.long.\n                    Encoded by data.CLIPTokenizer.\n        \"\"\"\n        xi = self.visual(imgs)\n        xt = self.textual(txt_ids)\n        return xi, xt\n\n    def param_groups(self):\n        groups = [{\n            'params': [\n                p for n, p in self.named_parameters()\n                if 'norm' in n or n.endswith('bias')\n            ],\n            'weight_decay': 0.0\n        }, {\n            'params': [\n                p for n, p in self.named_parameters()\n                if not ('norm' in n or n.endswith('bias'))\n            ]\n        }]\n        return groups\n\n\ndef 
_clip(pretrained=False,\n          pretrained_name=None,\n          model_cls=CLIP,\n          return_transforms=False,\n          return_tokenizer=False,\n          tokenizer_padding='eos',\n          dtype=torch.float32,\n          device='cpu',\n          **kwargs):\n    # init model\n    if pretrained and pretrained_name:\n        from sora import BUCKET, DOWNLOAD_TO_CACHE\n\n        # init a meta model\n        with torch.device('meta'):\n            model = model_cls(**kwargs)\n\n        # checkpoint path\n        checkpoint = f'models/clip/{pretrained_name}'\n        if dtype in (torch.float16, torch.bfloat16):\n            suffix = '-' + {\n                torch.float16: 'fp16',\n                torch.bfloat16: 'bf16'\n            }[dtype]\n            if object_exists(BUCKET, f'{checkpoint}{suffix}.pth'):\n                checkpoint = f'{checkpoint}{suffix}'\n        checkpoint += '.pth'\n\n        # load\n        model.load_state_dict(\n            torch.load(DOWNLOAD_TO_CACHE(checkpoint), map_location=device),\n            assign=True,\n            strict=False)\n    else:\n        # init a model on device\n        with torch.device(device):\n            model = model_cls(**kwargs)\n\n    # set device\n    output = (model,)\n\n    # init transforms\n    if return_transforms:\n        # mean and std\n        if 'siglip' in pretrained_name.lower():\n            mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]\n        else:\n            mean = [0.48145466, 0.4578275, 0.40821073]\n            std = [0.26862954, 0.26130258, 0.27577711]\n\n        # transforms\n        transforms = T.Compose([\n            T.Resize((model.image_size, model.image_size),\n                     interpolation=T.InterpolationMode.BICUBIC),\n            T.ToTensor(),\n            T.Normalize(mean=mean, std=std)\n        ])\n        output += (transforms,)\n\n    # init tokenizer\n    if return_tokenizer:\n        from sora import data\n        if 'siglip' in pretrained_name.lower():\n   
         tokenizer = data.HuggingfaceTokenizer(\n                name=f'timm/{pretrained_name}',\n                seq_len=model.text_len,\n                clean='canonicalize')\n        elif 'xlm' in pretrained_name.lower():\n            tokenizer = data.HuggingfaceTokenizer(\n                name='xlm-roberta-large',\n                seq_len=model.max_text_len - 2,\n                clean='whitespace')\n        elif 'mba' in pretrained_name.lower():\n            tokenizer = data.HuggingfaceTokenizer(\n                name='facebook/xlm-roberta-xl',\n                seq_len=model.max_text_len - 2,\n                clean='whitespace')\n        else:\n            tokenizer = data.CLIPTokenizer(\n                seq_len=model.text_len, padding=tokenizer_padding)\n        output += (tokenizer,)\n    return output[0] if len(output) == 1 else output\n\n\ndef clip_xlm_roberta_vit_h_14(\n        pretrained=False,\n        pretrained_name='open-clip-xlm-roberta-large-vit-huge-14',\n        **kwargs):\n    cfg = dict(\n        embed_dim=1024,\n        image_size=224,\n        patch_size=14,\n        vision_dim=1280,\n        vision_mlp_ratio=4,\n        vision_heads=16,\n        vision_layers=32,\n        vision_pool='token',\n        activation='gelu',\n        vocab_size=250002,\n        max_text_len=514,\n        type_size=1,\n        pad_id=1,\n        text_dim=1024,\n        text_heads=16,\n        text_layers=24,\n        text_post_norm=True,\n        text_dropout=0.1,\n        attn_dropout=0.0,\n        proj_dropout=0.0,\n        embedding_dropout=0.0)\n    cfg.update(**kwargs)\n    return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **cfg)\n\n\nclass WanImageEncoder(torch.nn.Module):\n\n    def __init__(self):\n        super().__init__()\n        # init model\n        self.model, self.transforms = clip_xlm_roberta_vit_h_14(\n            pretrained=False,\n            return_transforms=True,\n            return_tokenizer=False,\n            dtype=torch.float32,\n 
           device=\"cpu\")\n\n    def encode_image(self, videos):\n        # preprocess\n        size = (self.model.image_size,) * 2\n        videos = torch.cat([\n            F.interpolate(\n                u,\n                size=size,\n                mode='bicubic',\n                align_corners=False) for u in videos\n        ])\n        videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))\n\n        # forward\n        out = self.model.visual(videos, use_31_block=True)\n        return out\n"
  },
  {
    "path": "diffsynth/models/wan_video_mot.py",
    "content": "import torch\nfrom .wan_video_dit import DiTBlock, SelfAttention, rope_apply, flash_attention, modulate, MLP\nimport einops\nimport torch.nn as nn\n\n\nclass MotSelfAttention(SelfAttention):\n    def __init__(self, dim: int, num_heads: int, eps: float = 1e-6):\n        super().__init__(dim, num_heads, eps)\n    def forward(self, x, freqs, is_before_attn=False):\n        if is_before_attn:\n            q = self.norm_q(self.q(x))\n            k = self.norm_k(self.k(x))\n            v = self.v(x)\n            q = rope_apply(q, freqs, self.num_heads)\n            k = rope_apply(k, freqs, self.num_heads)\n            return q, k, v\n        else:\n            return self.o(x)\n\n\nclass MotWanAttentionBlock(DiTBlock):\n    def __init__(self, has_image_input, dim, num_heads, ffn_dim, eps=1e-6, block_id=0):\n        super().__init__(has_image_input, dim, num_heads, ffn_dim, eps=eps)\n        self.block_id = block_id\n\n        self.self_attn = MotSelfAttention(dim, num_heads, eps)\n\n\n    def forward(self, wan_block, x, context, t_mod, freqs, x_mot, context_mot, t_mod_mot, freqs_mot):\n\n        # 1. prepare scale parameter\n        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (\n            wan_block.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=1)\n        \n        scale_params_mot_ref = self.modulation + t_mod_mot.float()\n        scale_params_mot_ref = einops.rearrange(scale_params_mot_ref, '(b n) t c -> b n t c', n=1)\n        shift_msa_mot_ref, scale_msa_mot_ref, gate_msa_mot_ref, c_shift_msa_mot_ref, c_scale_msa_mot_ref, c_gate_msa_mot_ref = scale_params_mot_ref.chunk(6, dim=2)\n\n        # 2. 
Self-attention\n        input_x = modulate(wan_block.norm1(x), shift_msa, scale_msa)\n        # original block self-attn\n        attn1 = wan_block.self_attn\n        q = attn1.norm_q(attn1.q(input_x))\n        k = attn1.norm_k(attn1.k(input_x))\n        v = attn1.v(input_x)\n        q = rope_apply(q, freqs, attn1.num_heads)\n        k = rope_apply(k, freqs, attn1.num_heads)\n\n        # mot block self-attn\n        norm_x_mot = einops.rearrange(self.norm1(x_mot.float()), 'b (n t) c -> b n t c', n=1)\n        norm_x_mot = modulate(norm_x_mot, shift_msa_mot_ref, scale_msa_mot_ref).type_as(x_mot)\n        norm_x_mot = einops.rearrange(norm_x_mot, 'b n t c -> b (n t) c', n=1)\n        q_mot,k_mot,v_mot = self.self_attn(norm_x_mot, freqs_mot, is_before_attn=True)\n\n        tmp_hidden_states = flash_attention(\n            torch.cat([q, q_mot], dim=-2),\n            torch.cat([k, k_mot], dim=-2),\n            torch.cat([v, v_mot], dim=-2),\n            num_heads=attn1.num_heads)\n\n        attn_output, attn_output_mot = torch.split(tmp_hidden_states, [q.shape[-2], q_mot.shape[-2]], dim=-2)\n        \n        attn_output = attn1.o(attn_output)\n        x = wan_block.gate(x, gate_msa, attn_output)\n\n        attn_output_mot = self.self_attn(x=attn_output_mot,freqs=freqs_mot, is_before_attn=False)\n        # gate\n        attn_output_mot = einops.rearrange(attn_output_mot, 'b (n t) c -> b n t c', n=1)\n        attn_output_mot = attn_output_mot * gate_msa_mot_ref\n        attn_output_mot = einops.rearrange(attn_output_mot, 'b n t c -> b (n t) c', n=1)\n        x_mot = (x_mot.float() + attn_output_mot).type_as(x_mot)\n\n        # 3. 
cross-attention and feed-forward\n        x = x + wan_block.cross_attn(wan_block.norm3(x), context)\n        input_x = modulate(wan_block.norm2(x), shift_mlp, scale_mlp)\n        x = wan_block.gate(x, gate_mlp, wan_block.ffn(input_x))\n\n        x_mot = x_mot + self.cross_attn(self.norm3(x_mot),context_mot)\n        # modulate\n        norm_x_mot_ref = einops.rearrange(self.norm2(x_mot.float()), 'b (n t) c -> b n t c', n=1)\n        norm_x_mot_ref = (norm_x_mot_ref * (1 + c_scale_msa_mot_ref) + c_shift_msa_mot_ref).type_as(x_mot)\n        norm_x_mot_ref = einops.rearrange(norm_x_mot_ref, 'b n t c -> b (n t) c', n=1)\n        input_x_mot = self.ffn(norm_x_mot_ref)\n        # gate\n        input_x_mot = einops.rearrange(input_x_mot, 'b (n t) c -> b n t c', n=1)\n        input_x_mot = input_x_mot.float() * c_gate_msa_mot_ref\n        input_x_mot = einops.rearrange(input_x_mot, 'b n t c -> b (n t) c', n=1)\n        x_mot = (x_mot.float() + input_x_mot).type_as(x_mot)\n\n        return x, x_mot\n\n\nclass MotWanModel(torch.nn.Module):\n    def __init__(\n        self,\n        mot_layers=(0, 4, 8, 12, 16, 20, 24, 28, 32, 36),\n        patch_size=(1, 2, 2),\n        has_image_input=True,\n        has_image_pos_emb=False,\n        dim=5120,\n        num_heads=40,\n        ffn_dim=13824,\n        freq_dim=256,\n        text_dim=4096,\n        in_dim=36,\n        eps=1e-6,\n    ):\n        super().__init__()\n        self.mot_layers = mot_layers\n        self.freq_dim = freq_dim\n        self.dim = dim\n\n        self.mot_layers_mapping = {i: n for n, i in enumerate(self.mot_layers)}\n        self.head_dim = dim // num_heads\n\n        self.patch_embedding = nn.Conv3d(\n            in_dim, dim, kernel_size=patch_size, stride=patch_size)\n\n        self.text_embedding = nn.Sequential(\n            nn.Linear(text_dim, dim),\n            nn.GELU(approximate='tanh'),\n            nn.Linear(dim, dim)\n        )\n        self.time_embedding = nn.Sequential(\n            
nn.Linear(freq_dim, dim),\n            nn.SiLU(),\n            nn.Linear(dim, dim)\n        )\n        self.time_projection = nn.Sequential(\n            nn.SiLU(), nn.Linear(dim, dim * 6))\n        if has_image_input:\n            self.img_emb = MLP(1280, dim, has_pos_emb=has_image_pos_emb)\n\n        # mot blocks\n        self.blocks = torch.nn.ModuleList([\n            MotWanAttentionBlock(has_image_input, dim, num_heads, ffn_dim, eps, block_id=i)\n            for i in self.mot_layers\n        ])\n    \n\n    def patchify(self, x: torch.Tensor):\n        x = self.patch_embedding(x)\n        return x\n\n    def compute_freqs_mot(self, f, h, w, end: int = 1024, theta: float = 10000.0):\n        def precompute_freqs_cis(dim: int, start: int = 0, end: int = 1024, theta: float = 10000.0):\n            # 1d rope precompute\n            freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)\n                        [: (dim // 2)].double() / dim))\n            freqs = torch.outer(torch.arange(start, end, device=freqs.device), freqs)\n            freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64\n            return freqs_cis\n\n        f_freqs_cis = precompute_freqs_cis(self.head_dim - 2 * (self.head_dim // 3), -f, end, theta)\n        h_freqs_cis = precompute_freqs_cis(self.head_dim // 3, 0, end, theta)\n        w_freqs_cis = precompute_freqs_cis(self.head_dim // 3, 0, end, theta)\n\n        freqs = torch.cat([\n            f_freqs_cis[:f].view(f, 1, 1, -1).expand(f, h, w, -1),\n            h_freqs_cis[:h].view(1, h, 1, -1).expand(f, h, w, -1),\n            w_freqs_cis[:w].view(1, 1, w, -1).expand(f, h, w, -1)\n        ], dim=-1).reshape(f * h * w, 1, -1)\n        return freqs\n\n    def forward(self, wan_block, x, context, t_mod, freqs, x_mot, context_mot, t_mod_mot, freqs_mot, block_id):\n        block = self.blocks[self.mot_layers_mapping[block_id]]\n        x, x_mot = block(wan_block, x, context, t_mod, freqs, x_mot, context_mot, t_mod_mot, freqs_mot)\n  
      return x, x_mot\n"
  },
  {
    "path": "diffsynth/models/wan_video_motion_controller.py",
    "content": "import torch\nimport torch.nn as nn\nfrom .wan_video_dit import sinusoidal_embedding_1d\n\n\n\nclass WanMotionControllerModel(torch.nn.Module):\n    def __init__(self, freq_dim=256, dim=1536):\n        super().__init__()\n        self.freq_dim = freq_dim\n        self.linear = nn.Sequential(\n            nn.Linear(freq_dim, dim),\n            nn.SiLU(),\n            nn.Linear(dim, dim),\n            nn.SiLU(),\n            nn.Linear(dim, dim * 6),\n        )\n\n    def forward(self, motion_bucket_id):\n        emb = sinusoidal_embedding_1d(self.freq_dim, motion_bucket_id * 10)\n        emb = self.linear(emb)\n        return emb\n\n    def init(self):\n        state_dict = self.linear[-1].state_dict()\n        state_dict = {i: state_dict[i] * 0 for i in state_dict}\n        self.linear[-1].load_state_dict(state_dict)\n"
  },
  {
    "path": "diffsynth/models/wan_video_text_encoder.py",
    "content": "import math\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom transformers import AutoTokenizer\nimport ftfy\nimport html\nimport string\nimport regex as re\n\ndef fp16_clamp(x):\n    if x.dtype == torch.float16 and torch.isinf(x).any():\n        clamp = torch.finfo(x.dtype).max - 1000\n        x = torch.clamp(x, min=-clamp, max=clamp)\n    return x\n\n\nclass GELU(nn.Module):\n\n    def forward(self, x):\n        return 0.5 * x * (1.0 + torch.tanh(\n            math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))\n\n\nclass T5LayerNorm(nn.Module):\n\n    def __init__(self, dim, eps=1e-6):\n        super(T5LayerNorm, self).__init__()\n        self.dim = dim\n        self.eps = eps\n        self.weight = nn.Parameter(torch.ones(dim))\n\n    def forward(self, x):\n        x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) +\n                            self.eps)\n        if self.weight.dtype in [torch.float16, torch.bfloat16]:\n            x = x.type_as(self.weight)\n        return self.weight * x\n\n\nclass T5Attention(nn.Module):\n\n    def __init__(self, dim, dim_attn, num_heads, dropout=0.1):\n        assert dim_attn % num_heads == 0\n        super(T5Attention, self).__init__()\n        self.dim = dim\n        self.dim_attn = dim_attn\n        self.num_heads = num_heads\n        self.head_dim = dim_attn // num_heads\n\n        # layers\n        self.q = nn.Linear(dim, dim_attn, bias=False)\n        self.k = nn.Linear(dim, dim_attn, bias=False)\n        self.v = nn.Linear(dim, dim_attn, bias=False)\n        self.o = nn.Linear(dim_attn, dim, bias=False)\n        self.dropout = nn.Dropout(dropout)\n\n    def forward(self, x, context=None, mask=None, pos_bias=None):\n        \"\"\"\n        x:          [B, L1, C].\n        context:    [B, L2, C] or None.\n        mask:       [B, L2] or [B, L1, L2] or None.\n        \"\"\"\n        # check inputs\n        context = x if context is None else context\n    
    b, n, c = x.size(0), self.num_heads, self.head_dim\n\n        # compute query, key, value\n        q = self.q(x).view(b, -1, n, c)\n        k = self.k(context).view(b, -1, n, c)\n        v = self.v(context).view(b, -1, n, c)\n\n        # attention bias\n        attn_bias = x.new_zeros(b, n, q.size(1), k.size(1))\n        if pos_bias is not None:\n            attn_bias += pos_bias\n        if mask is not None:\n            assert mask.ndim in [2, 3]\n            mask = mask.view(b, 1, 1,\n                             -1) if mask.ndim == 2 else mask.unsqueeze(1)\n            attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min)\n\n        # compute attention (T5 does not use scaling)\n        attn = torch.einsum('binc,bjnc->bnij', q, k) + attn_bias\n        attn = F.softmax(attn.float(), dim=-1).type_as(attn)\n        x = torch.einsum('bnij,bjnc->binc', attn, v)\n\n        # output\n        x = x.reshape(b, -1, n * c)\n        x = self.o(x)\n        x = self.dropout(x)\n        return x\n\n\nclass T5FeedForward(nn.Module):\n\n    def __init__(self, dim, dim_ffn, dropout=0.1):\n        super(T5FeedForward, self).__init__()\n        self.dim = dim\n        self.dim_ffn = dim_ffn\n\n        # layers\n        self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU())\n        self.fc1 = nn.Linear(dim, dim_ffn, bias=False)\n        self.fc2 = nn.Linear(dim_ffn, dim, bias=False)\n        self.dropout = nn.Dropout(dropout)\n\n    def forward(self, x):\n        x = self.fc1(x) * self.gate(x)\n        x = self.dropout(x)\n        x = self.fc2(x)\n        x = self.dropout(x)\n        return x\n\n\nclass T5SelfAttention(nn.Module):\n\n    def __init__(self,\n                 dim,\n                 dim_attn,\n                 dim_ffn,\n                 num_heads,\n                 num_buckets,\n                 shared_pos=True,\n                 dropout=0.1):\n        super(T5SelfAttention, self).__init__()\n        self.dim = dim\n        self.dim_attn = 
dim_attn\n        self.dim_ffn = dim_ffn\n        self.num_heads = num_heads\n        self.num_buckets = num_buckets\n        self.shared_pos = shared_pos\n\n        # layers\n        self.norm1 = T5LayerNorm(dim)\n        self.attn = T5Attention(dim, dim_attn, num_heads, dropout)\n        self.norm2 = T5LayerNorm(dim)\n        self.ffn = T5FeedForward(dim, dim_ffn, dropout)\n        self.pos_embedding = None if shared_pos else T5RelativeEmbedding(\n            num_buckets, num_heads, bidirectional=True)\n\n    def forward(self, x, mask=None, pos_bias=None):\n        e = pos_bias if self.shared_pos else self.pos_embedding(\n            x.size(1), x.size(1))\n        x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))\n        x = fp16_clamp(x + self.ffn(self.norm2(x)))\n        return x\n\n\nclass T5RelativeEmbedding(nn.Module):\n\n    def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):\n        super(T5RelativeEmbedding, self).__init__()\n        self.num_buckets = num_buckets\n        self.num_heads = num_heads\n        self.bidirectional = bidirectional\n        self.max_dist = max_dist\n\n        # layers\n        self.embedding = nn.Embedding(num_buckets, num_heads)\n\n    def forward(self, lq, lk):\n        device = self.embedding.weight.device\n        # rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \\\n        #     torch.arange(lq).unsqueeze(1).to(device)\n        rel_pos = torch.arange(lk, device=device).unsqueeze(0) - \\\n            torch.arange(lq, device=device).unsqueeze(1)\n        rel_pos = self._relative_position_bucket(rel_pos)\n        rel_pos_embeds = self.embedding(rel_pos)\n        rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(\n            0)  # [1, N, Lq, Lk]\n        return rel_pos_embeds.contiguous()\n\n    def _relative_position_bucket(self, rel_pos):\n        # preprocess\n        if self.bidirectional:\n            num_buckets = self.num_buckets // 2\n            rel_buckets = 
(rel_pos > 0).long() * num_buckets\n            rel_pos = torch.abs(rel_pos)\n        else:\n            num_buckets = self.num_buckets\n            rel_buckets = 0\n            rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos))\n\n        # embeddings for small and large positions\n        max_exact = num_buckets // 2\n        rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) /\n                                     math.log(self.max_dist / max_exact) *\n                                     (num_buckets - max_exact)).long()\n        rel_pos_large = torch.min(\n            rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1))\n        rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)\n        return rel_buckets\n\ndef init_weights(m):\n    if isinstance(m, T5LayerNorm):\n        nn.init.ones_(m.weight)\n    elif isinstance(m, T5FeedForward):\n        nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5)\n        nn.init.normal_(m.fc1.weight, std=m.dim**-0.5)\n        nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5)\n    elif isinstance(m, T5Attention):\n        nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn)**-0.5)\n        nn.init.normal_(m.k.weight, std=m.dim**-0.5)\n        nn.init.normal_(m.v.weight, std=m.dim**-0.5)\n        nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn)**-0.5)\n    elif isinstance(m, T5RelativeEmbedding):\n        nn.init.normal_(\n            m.embedding.weight, std=(2 * m.num_buckets * m.num_heads)**-0.5)\n\n\nclass WanTextEncoder(torch.nn.Module):\n\n    def __init__(self,\n                 vocab=256384,\n                 dim=4096,\n                 dim_attn=4096,\n                 dim_ffn=10240,\n                 num_heads=64,\n                 num_layers=24,\n                 num_buckets=32,\n                 shared_pos=False,\n                 dropout=0.1):\n        super(WanTextEncoder, self).__init__()\n        self.dim = dim\n        self.dim_attn = dim_attn\n     
   self.dim_ffn = dim_ffn\n        self.num_heads = num_heads\n        self.num_layers = num_layers\n        self.num_buckets = num_buckets\n        self.shared_pos = shared_pos\n\n        # layers\n        self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \\\n            else nn.Embedding(vocab, dim)\n        self.pos_embedding = T5RelativeEmbedding(\n            num_buckets, num_heads, bidirectional=True) if shared_pos else None\n        self.dropout = nn.Dropout(dropout)\n        self.blocks = nn.ModuleList([\n            T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,\n                            shared_pos, dropout) for _ in range(num_layers)\n        ])\n        self.norm = T5LayerNorm(dim)\n\n        # initialize weights\n        self.apply(init_weights)\n\n    def forward(self, ids, mask=None):\n        x = self.token_embedding(ids)\n        x = self.dropout(x)\n        e = self.pos_embedding(x.size(1),\n                               x.size(1)) if self.shared_pos else None\n        for block in self.blocks:\n            x = block(x, mask, pos_bias=e)\n        x = self.norm(x)\n        x = self.dropout(x)\n        return x\n\n\ndef basic_clean(text):\n    text = ftfy.fix_text(text)\n    text = html.unescape(html.unescape(text))\n    return text.strip()\n\n\ndef whitespace_clean(text):\n    text = re.sub(r'\\s+', ' ', text)\n    text = text.strip()\n    return text\n\n\ndef canonicalize(text, keep_punctuation_exact_string=None):\n    text = text.replace('_', ' ')\n    if keep_punctuation_exact_string:\n        text = keep_punctuation_exact_string.join(\n            part.translate(str.maketrans('', '', string.punctuation))\n            for part in text.split(keep_punctuation_exact_string))\n    else:\n        text = text.translate(str.maketrans('', '', string.punctuation))\n    text = text.lower()\n    text = re.sub(r'\\s+', ' ', text)\n    return text.strip()\n\n\nclass HuggingfaceTokenizer:\n\n    def __init__(self, name, 
seq_len=None, clean=None, **kwargs):\n        assert clean in (None, 'whitespace', 'lower', 'canonicalize')\n        self.name = name\n        self.seq_len = seq_len\n        self.clean = clean\n\n        # init tokenizer\n        self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)\n        self.vocab_size = self.tokenizer.vocab_size\n\n    def __call__(self, sequence, **kwargs):\n        return_mask = kwargs.pop('return_mask', False)\n\n        # arguments\n        _kwargs = {'return_tensors': 'pt'}\n        if self.seq_len is not None:\n            _kwargs.update({\n                'padding': 'max_length',\n                'truncation': True,\n                'max_length': self.seq_len\n            })\n        _kwargs.update(**kwargs)\n\n        # tokenization\n        if isinstance(sequence, str):\n            sequence = [sequence]\n        if self.clean:\n            sequence = [self._clean(u) for u in sequence]\n        ids = self.tokenizer(sequence, **_kwargs)\n\n        # output\n        if return_mask:\n            return ids.input_ids, ids.attention_mask\n        else:\n            return ids.input_ids\n    \n    def _clean(self, text):\n        if self.clean == 'whitespace':\n            text = whitespace_clean(basic_clean(text))\n        elif self.clean == 'lower':\n            text = whitespace_clean(basic_clean(text)).lower()\n        elif self.clean == 'canonicalize':\n            text = canonicalize(basic_clean(text))\n        return text"
  },
  {
    "path": "diffsynth/models/wan_video_vace.py",
    "content": "import torch\nfrom .wan_video_dit import DiTBlock\nfrom ..core.gradient import gradient_checkpoint_forward\n\nclass VaceWanAttentionBlock(DiTBlock):\n    def __init__(self, has_image_input, dim, num_heads, ffn_dim, eps=1e-6, block_id=0):\n        super().__init__(has_image_input, dim, num_heads, ffn_dim, eps=eps)\n        self.block_id = block_id\n        if block_id == 0:\n            self.before_proj = torch.nn.Linear(self.dim, self.dim)\n        self.after_proj = torch.nn.Linear(self.dim, self.dim)\n\n    def forward(self, c, x, context, t_mod, freqs):\n        if self.block_id == 0:\n            c = self.before_proj(c) + x\n            all_c = []\n        else:\n            all_c = list(torch.unbind(c))\n            c = all_c.pop(-1)\n        c = super().forward(c, context, t_mod, freqs)\n        c_skip = self.after_proj(c)\n        all_c += [c_skip, c]\n        c = torch.stack(all_c)\n        return c\n\n\nclass VaceWanModel(torch.nn.Module):\n    def __init__(\n        self,\n        vace_layers=(0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28),\n        vace_in_dim=96,\n        patch_size=(1, 2, 2),\n        has_image_input=False,\n        dim=1536,\n        num_heads=12,\n        ffn_dim=8960,\n        eps=1e-6,\n    ):\n        super().__init__()\n        self.vace_layers = vace_layers\n        self.vace_in_dim = vace_in_dim\n        self.vace_layers_mapping = {i: n for n, i in enumerate(self.vace_layers)}\n\n        # vace blocks\n        self.vace_blocks = torch.nn.ModuleList([\n            VaceWanAttentionBlock(has_image_input, dim, num_heads, ffn_dim, eps, block_id=i)\n            for i in self.vace_layers\n        ])\n\n        # vace patch embeddings\n        self.vace_patch_embedding = torch.nn.Conv3d(vace_in_dim, dim, kernel_size=patch_size, stride=patch_size)\n\n    def forward(\n        self, x, vace_context, context, t_mod, freqs,\n        use_gradient_checkpointing: bool = False,\n        
use_gradient_checkpointing_offload: bool = False,\n    ):\n        c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]\n        c = [u.flatten(2).transpose(1, 2) for u in c]\n        c = torch.cat([\n            torch.cat([u, u.new_zeros(1, x.shape[1] - u.size(1), u.size(2))],\n                      dim=1) for u in c\n        ])\n        \n        for block in self.vace_blocks:\n            c = gradient_checkpoint_forward(\n                block,\n                use_gradient_checkpointing,\n                use_gradient_checkpointing_offload,\n                c, x, context, t_mod, freqs\n            )\n            \n        hints = torch.unbind(c)[:-1]\n        return hints\n"
  },
  {
    "path": "diffsynth/models/wan_video_vae.py",
    "content": "from einops import rearrange, repeat\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom tqdm import tqdm\n\nCACHE_T = 2\n\n\ndef check_is_instance(model, module_class):\n    if isinstance(model, module_class):\n        return True\n    if hasattr(model, \"module\") and isinstance(model.module, module_class):\n        return True\n    return False\n\n\ndef block_causal_mask(x, block_size):\n    # params\n    b, n, s, _, device = *x.size(), x.device\n    assert s % block_size == 0\n    num_blocks = s // block_size\n\n    # build mask\n    mask = torch.zeros(b, n, s, s, dtype=torch.bool, device=device)\n    for i in range(num_blocks):\n        mask[:, :,\n             i * block_size:(i + 1) * block_size, :(i + 1) * block_size] = 1\n    return mask\n\n\nclass CausalConv3d(nn.Conv3d):\n    \"\"\"\n    Causal 3d convolusion.\n    \"\"\"\n\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        self._padding = (self.padding[2], self.padding[2], self.padding[1],\n                         self.padding[1], 2 * self.padding[0], 0)\n        self.padding = (0, 0, 0)\n\n    def forward(self, x, cache_x=None):\n        padding = list(self._padding)\n        if cache_x is not None and self._padding[4] > 0:\n            cache_x = cache_x.to(x.device)\n            x = torch.cat([cache_x, x], dim=2)\n            padding[4] -= cache_x.shape[2]\n        x = F.pad(x, padding)\n\n        return super().forward(x)\n\n\nclass RMS_norm(nn.Module):\n\n    def __init__(self, dim, channel_first=True, images=True, bias=False):\n        super().__init__()\n        broadcastable_dims = (1, 1, 1) if not images else (1, 1)\n        shape = (dim, *broadcastable_dims) if channel_first else (dim,)\n\n        self.channel_first = channel_first\n        self.scale = dim**0.5\n        self.gamma = nn.Parameter(torch.ones(shape))\n        self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.\n\n    def forward(self, 
x):\n        return F.normalize(\n            x, dim=(1 if self.channel_first else\n                    -1)) * self.scale * self.gamma + self.bias\n\n\nclass Upsample(nn.Upsample):\n\n    def forward(self, x):\n        \"\"\"\n        Fix bfloat16 support for nearest neighbor interpolation.\n        \"\"\"\n        return super().forward(x.float()).type_as(x)\n\n\nclass Resample(nn.Module):\n\n    def __init__(self, dim, mode):\n        assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',\n                        'downsample3d')\n        super().__init__()\n        self.dim = dim\n        self.mode = mode\n\n        # layers\n        if mode == 'upsample2d':\n            self.resample = nn.Sequential(\n                Upsample(scale_factor=(2., 2.), mode='nearest-exact'),\n                nn.Conv2d(dim, dim // 2, 3, padding=1))\n        elif mode == 'upsample3d':\n            self.resample = nn.Sequential(\n                Upsample(scale_factor=(2., 2.), mode='nearest-exact'),\n                nn.Conv2d(dim, dim // 2, 3, padding=1))\n            self.time_conv = CausalConv3d(dim,\n                                          dim * 2, (3, 1, 1),\n                                          padding=(1, 0, 0))\n\n        elif mode == 'downsample2d':\n            self.resample = nn.Sequential(\n                nn.ZeroPad2d((0, 1, 0, 1)),\n                nn.Conv2d(dim, dim, 3, stride=(2, 2)))\n        elif mode == 'downsample3d':\n            self.resample = nn.Sequential(\n                nn.ZeroPad2d((0, 1, 0, 1)),\n                nn.Conv2d(dim, dim, 3, stride=(2, 2)))\n            self.time_conv = CausalConv3d(dim,\n                                          dim, (3, 1, 1),\n                                          stride=(2, 1, 1),\n                                          padding=(0, 0, 0))\n\n        else:\n            self.resample = nn.Identity()\n\n    def forward(self, x, feat_cache=None, feat_idx=[0]):\n        b, c, t, h, w = x.size()\n        
if self.mode == 'upsample3d':\n            if feat_cache is not None:\n                idx = feat_idx[0]\n                if feat_cache[idx] is None:\n                    feat_cache[idx] = 'Rep'\n                    feat_idx[0] += 1\n                else:\n\n                    cache_x = x[:, :, -CACHE_T:, :, :].clone()\n                    if cache_x.shape[2] < 2 and feat_cache[\n                            idx] is not None and feat_cache[idx] != 'Rep':\n                        # cache last frame of last two chunk\n                        cache_x = torch.cat([\n                            feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(\n                                cache_x.device), cache_x\n                        ],\n                                            dim=2)\n                    if cache_x.shape[2] < 2 and feat_cache[\n                            idx] is not None and feat_cache[idx] == 'Rep':\n                        cache_x = torch.cat([\n                            torch.zeros_like(cache_x).to(cache_x.device),\n                            cache_x\n                        ],\n                                            dim=2)\n                    if feat_cache[idx] == 'Rep':\n                        x = self.time_conv(x)\n                    else:\n                        x = self.time_conv(x, feat_cache[idx])\n                    feat_cache[idx] = cache_x\n                    feat_idx[0] += 1\n\n                    x = x.reshape(b, 2, c, t, h, w)\n                    x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),\n                                    3)\n                    x = x.reshape(b, c, t * 2, h, w)\n        t = x.shape[2]\n        x = rearrange(x, 'b c t h w -> (b t) c h w')\n        x = self.resample(x)\n        x = rearrange(x, '(b t) c h w -> b c t h w', t=t)\n\n        if self.mode == 'downsample3d':\n            if feat_cache is not None:\n                idx = feat_idx[0]\n                if feat_cache[idx] is None:\n      
              feat_cache[idx] = x.clone()\n                    feat_idx[0] += 1\n                else:\n                    cache_x = x[:, :, -1:, :, :].clone()\n                    x = self.time_conv(\n                        torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))\n                    feat_cache[idx] = cache_x\n                    feat_idx[0] += 1\n        return x, feat_cache, feat_idx\n\n    def init_weight(self, conv):\n        conv_weight = conv.weight\n        nn.init.zeros_(conv_weight)\n        c1, c2, t, h, w = conv_weight.size()\n        one_matrix = torch.eye(c1, c2)\n        init_matrix = one_matrix\n        nn.init.zeros_(conv_weight)\n        conv_weight.data[:, :, 1, 0, 0] = init_matrix\n        conv.weight.data.copy_(conv_weight)\n        nn.init.zeros_(conv.bias.data)\n\n    def init_weight2(self, conv):\n        conv_weight = conv.weight.data\n        nn.init.zeros_(conv_weight)\n        c1, c2, t, h, w = conv_weight.size()\n        init_matrix = torch.eye(c1 // 2, c2)\n        conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix\n        conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix\n        conv.weight.data.copy_(conv_weight)\n        nn.init.zeros_(conv.bias.data)\n\n\n\ndef patchify(x, patch_size):\n    if patch_size == 1:\n        return x\n    if x.dim() == 4:\n        x = rearrange(x, \"b c (h q) (w r) -> b (c r q) h w\", q=patch_size, r=patch_size)\n    elif x.dim() == 5:\n        x = rearrange(x,\n                      \"b c f (h q) (w r) -> b (c r q) f h w\",\n                      q=patch_size,\n                      r=patch_size)\n    else:\n        raise ValueError(f\"Invalid input shape: {x.shape}\")\n    return x\n\n\ndef unpatchify(x, patch_size):\n    if patch_size == 1:\n        return x\n    if x.dim() == 4:\n        x = rearrange(x, \"b (c r q) h w -> b c (h q) (w r)\", q=patch_size, r=patch_size)\n    elif x.dim() == 5:\n        x = rearrange(x,\n                      \"b (c r q) f h w -> b c f (h q) (w r)\",\n    
                  q=patch_size,\n                      r=patch_size)\n    return x\n\n\nclass Resample38(Resample):\n\n    def __init__(self, dim, mode):\n        assert mode in (\n            \"none\",\n            \"upsample2d\",\n            \"upsample3d\",\n            \"downsample2d\",\n            \"downsample3d\",\n        )\n        super(Resample, self).__init__()\n        self.dim = dim\n        self.mode = mode\n\n        # layers\n        if mode == \"upsample2d\":\n            self.resample = nn.Sequential(\n                Upsample(scale_factor=(2.0, 2.0), mode=\"nearest-exact\"),\n                nn.Conv2d(dim, dim, 3, padding=1),\n            )\n        elif mode == \"upsample3d\":\n            self.resample = nn.Sequential(\n                Upsample(scale_factor=(2.0, 2.0), mode=\"nearest-exact\"),\n                nn.Conv2d(dim, dim, 3, padding=1),\n            )\n            self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))\n        elif mode == \"downsample2d\":\n            self.resample = nn.Sequential(\n                nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))\n            )\n        elif mode == \"downsample3d\":\n            self.resample = nn.Sequential(\n                nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))\n            )\n            self.time_conv = CausalConv3d(\n                dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)\n            )\n        else:\n            self.resample = nn.Identity()\n\nclass ResidualBlock(nn.Module):\n\n    def __init__(self, in_dim, out_dim, dropout=0.0):\n        super().__init__()\n        self.in_dim = in_dim\n        self.out_dim = out_dim\n\n        # layers\n        self.residual = nn.Sequential(\n            RMS_norm(in_dim, images=False), nn.SiLU(),\n            CausalConv3d(in_dim, out_dim, 3, padding=1),\n            RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),\n            
CausalConv3d(out_dim, out_dim, 3, padding=1))\n        self.shortcut = CausalConv3d(in_dim, out_dim, 1) \\\n            if in_dim != out_dim else nn.Identity()\n\n    def forward(self, x, feat_cache=None, feat_idx=[0]):\n        h = self.shortcut(x)\n        for layer in self.residual:\n            if check_is_instance(layer, CausalConv3d) and feat_cache is not None:\n                idx = feat_idx[0]\n                cache_x = x[:, :, -CACHE_T:, :, :].clone()\n                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:\n                    # cache last frame of last two chunk\n                    cache_x = torch.cat([\n                        feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(\n                            cache_x.device), cache_x\n                    ],\n                                        dim=2)\n                x = layer(x, feat_cache[idx])\n                feat_cache[idx] = cache_x\n                feat_idx[0] += 1\n            else:\n                x = layer(x)\n        return x + h, feat_cache, feat_idx\n\n\nclass AttentionBlock(nn.Module):\n    \"\"\"\n    Causal self-attention with a single head.\n    \"\"\"\n\n    def __init__(self, dim):\n        super().__init__()\n        self.dim = dim\n\n        # layers\n        self.norm = RMS_norm(dim)\n        self.to_qkv = nn.Conv2d(dim, dim * 3, 1)\n        self.proj = nn.Conv2d(dim, dim, 1)\n\n        # zero out the last layer params\n        nn.init.zeros_(self.proj.weight)\n\n    def forward(self, x):\n        identity = x\n        b, c, t, h, w = x.size()\n        x = rearrange(x, 'b c t h w -> (b t) c h w')\n        x = self.norm(x)\n        # compute query, key, value\n        q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, -1).permute(\n            0, 1, 3, 2).contiguous().chunk(3, dim=-1)\n\n        # apply attention\n        x = F.scaled_dot_product_attention(\n            q,\n            k,\n            v,\n            #attn_mask=block_causal_mask(q, block_size=h * 
w)\n        )\n        x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)\n\n        # output\n        x = self.proj(x)\n        x = rearrange(x, '(b t) c h w-> b c t h w', t=t)\n        return x + identity\n\n\nclass AvgDown3D(nn.Module):\n    def __init__(\n        self,\n        in_channels,\n        out_channels,\n        factor_t,\n        factor_s=1,\n    ):\n        super().__init__()\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.factor_t = factor_t\n        self.factor_s = factor_s\n        self.factor = self.factor_t * self.factor_s * self.factor_s\n\n        assert in_channels * self.factor % out_channels == 0\n        self.group_size = in_channels * self.factor // out_channels\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t\n        pad = (0, 0, 0, 0, pad_t, 0)\n        x = F.pad(x, pad)\n        B, C, T, H, W = x.shape\n        x = x.view(\n            B,\n            C,\n            T // self.factor_t,\n            self.factor_t,\n            H // self.factor_s,\n            self.factor_s,\n            W // self.factor_s,\n            self.factor_s,\n        )\n        x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()\n        x = x.view(\n            B,\n            C * self.factor,\n            T // self.factor_t,\n            H // self.factor_s,\n            W // self.factor_s,\n        )\n        x = x.view(\n            B,\n            self.out_channels,\n            self.group_size,\n            T // self.factor_t,\n            H // self.factor_s,\n            W // self.factor_s,\n        )\n        x = x.mean(dim=2)\n        return x\n\n\nclass DupUp3D(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        factor_t,\n        factor_s=1,\n    ):\n        super().__init__()\n        self.in_channels = in_channels\n        self.out_channels = 
out_channels\n\n        self.factor_t = factor_t\n        self.factor_s = factor_s\n        self.factor = self.factor_t * self.factor_s * self.factor_s\n\n        assert out_channels * self.factor % in_channels == 0\n        self.repeats = out_channels * self.factor // in_channels\n\n    def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor:\n        x = x.repeat_interleave(self.repeats, dim=1)\n        x = x.view(\n            x.size(0),\n            self.out_channels,\n            self.factor_t,\n            self.factor_s,\n            self.factor_s,\n            x.size(2),\n            x.size(3),\n            x.size(4),\n        )\n        x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()\n        x = x.view(\n            x.size(0),\n            self.out_channels,\n            x.size(2) * self.factor_t,\n            x.size(4) * self.factor_s,\n            x.size(6) * self.factor_s,\n        )\n        if first_chunk:\n            x = x[:, :, self.factor_t - 1 :, :, :]\n        return x\n\n\nclass Down_ResidualBlock(nn.Module):\n    def __init__(\n        self, in_dim, out_dim, dropout, mult, temperal_downsample=False, down_flag=False\n    ):\n        super().__init__()\n\n        # Shortcut path with downsample\n        self.avg_shortcut = AvgDown3D(\n            in_dim,\n            out_dim,\n            factor_t=2 if temperal_downsample else 1,\n            factor_s=2 if down_flag else 1,\n        )\n\n        # Main path with residual blocks and downsample\n        downsamples = []\n        for _ in range(mult):\n            downsamples.append(ResidualBlock(in_dim, out_dim, dropout))\n            in_dim = out_dim\n\n        # Add the final downsample block\n        if down_flag:\n            mode = \"downsample3d\" if temperal_downsample else \"downsample2d\"\n            downsamples.append(Resample38(out_dim, mode=mode))\n\n        self.downsamples = nn.Sequential(*downsamples)\n\n    def forward(self, x, feat_cache=None, feat_idx=[0]):\n    
    x_copy = x.clone()\n        for module in self.downsamples:\n            x, feat_cache, feat_idx = module(x, feat_cache, feat_idx)\n\n        return x + self.avg_shortcut(x_copy), feat_cache, feat_idx\n\n\nclass Up_ResidualBlock(nn.Module):\n    def __init__(\n        self, in_dim, out_dim, dropout, mult, temperal_upsample=False, up_flag=False\n    ):\n        super().__init__()\n        # Shortcut path with upsample\n        if up_flag:\n            self.avg_shortcut = DupUp3D(\n                in_dim,\n                out_dim,\n                factor_t=2 if temperal_upsample else 1,\n                factor_s=2 if up_flag else 1,\n            )\n        else:\n            self.avg_shortcut = None\n\n        # Main path with residual blocks and upsample\n        upsamples = []\n        for _ in range(mult):\n            upsamples.append(ResidualBlock(in_dim, out_dim, dropout))\n            in_dim = out_dim\n\n        # Add the final upsample block\n        if up_flag:\n            mode = \"upsample3d\" if temperal_upsample else \"upsample2d\"\n            upsamples.append(Resample38(out_dim, mode=mode))\n\n        self.upsamples = nn.Sequential(*upsamples)\n\n    def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):\n        x_main = x.clone()\n        for module in self.upsamples:\n            x_main, feat_cache, feat_idx = module(x_main, feat_cache, feat_idx)\n        if self.avg_shortcut is not None:\n            x_shortcut = self.avg_shortcut(x, first_chunk)\n            return x_main + x_shortcut, feat_cache, feat_idx\n        else:\n            return x_main, feat_cache, feat_idx\n\n\nclass Encoder3d(nn.Module):\n\n    def __init__(self,\n                 dim=128,\n                 z_dim=4,\n                 dim_mult=[1, 2, 4, 4],\n                 num_res_blocks=2,\n                 attn_scales=[],\n                 temperal_downsample=[True, True, False],\n                 dropout=0.0):\n        super().__init__()\n        self.dim = 
dim\n        self.z_dim = z_dim\n        self.dim_mult = dim_mult\n        self.num_res_blocks = num_res_blocks\n        self.attn_scales = attn_scales\n        self.temperal_downsample = temperal_downsample\n\n        # dimensions\n        dims = [dim * u for u in [1] + dim_mult]\n        scale = 1.0\n\n        # init block\n        self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)\n\n        # downsample blocks\n        downsamples = []\n        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):\n            # residual (+attention) blocks\n            for _ in range(num_res_blocks):\n                downsamples.append(ResidualBlock(in_dim, out_dim, dropout))\n                if scale in attn_scales:\n                    downsamples.append(AttentionBlock(out_dim))\n                in_dim = out_dim\n\n            # downsample block\n            if i != len(dim_mult) - 1:\n                mode = 'downsample3d' if temperal_downsample[\n                    i] else 'downsample2d'\n                downsamples.append(Resample(out_dim, mode=mode))\n                scale /= 2.0\n        self.downsamples = nn.Sequential(*downsamples)\n\n        # middle blocks\n        self.middle = nn.Sequential(ResidualBlock(out_dim, out_dim, dropout),\n                                    AttentionBlock(out_dim),\n                                    ResidualBlock(out_dim, out_dim, dropout))\n\n        # output blocks\n        self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(),\n                                  CausalConv3d(out_dim, z_dim, 3, padding=1))\n\n    def forward(self, x, feat_cache=None, feat_idx=[0]):\n        if feat_cache is not None:\n            idx = feat_idx[0]\n            cache_x = x[:, :, -CACHE_T:, :, :].clone()\n            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:\n                # cache last frame of last two chunk\n                cache_x = torch.cat([\n                    feat_cache[idx][:, :, -1, :, 
:].unsqueeze(2).to(\n                        cache_x.device), cache_x\n                ],\n                                    dim=2)\n            x = self.conv1(x, feat_cache[idx])\n            feat_cache[idx] = cache_x\n            feat_idx[0] += 1\n        else:\n            x = self.conv1(x)\n\n        ## downsamples\n        for layer in self.downsamples:\n            if feat_cache is not None:\n                x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx)\n            else:\n                x = layer(x)\n\n        ## middle\n        for layer in self.middle:\n            if check_is_instance(layer, ResidualBlock) and feat_cache is not None:\n                x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx)\n            else:\n                x = layer(x)\n\n        ## head\n        for layer in self.head:\n            if check_is_instance(layer, CausalConv3d) and feat_cache is not None:\n                idx = feat_idx[0]\n                cache_x = x[:, :, -CACHE_T:, :, :].clone()\n                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:\n                    # cache last frame of last two chunk\n                    cache_x = torch.cat([\n                        feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(\n                            cache_x.device), cache_x\n                    ],\n                                        dim=2)\n                x = layer(x, feat_cache[idx])\n                feat_cache[idx] = cache_x\n                feat_idx[0] += 1\n            else:\n                x = layer(x)\n        return x, feat_cache, feat_idx\n\n\nclass Encoder3d_38(nn.Module):\n\n    def __init__(self,\n                 dim=128,\n                 z_dim=4,\n                 dim_mult=[1, 2, 4, 4],\n                 num_res_blocks=2,\n                 attn_scales=[],\n                 temperal_downsample=[False, True, True],\n                 dropout=0.0):\n        super().__init__()\n        self.dim = dim\n        self.z_dim = 
z_dim\n        self.dim_mult = dim_mult\n        self.num_res_blocks = num_res_blocks\n        self.attn_scales = attn_scales\n        self.temperal_downsample = temperal_downsample\n\n        # dimensions\n        dims = [dim * u for u in [1] + dim_mult]\n        scale = 1.0\n\n        # init block\n        self.conv1 = CausalConv3d(12, dims[0], 3, padding=1)\n\n        # downsample blocks\n        downsamples = []\n        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):\n            t_down_flag = (\n                temperal_downsample[i] if i < len(temperal_downsample) else False\n            )\n            downsamples.append(\n                Down_ResidualBlock(\n                    in_dim=in_dim,\n                    out_dim=out_dim,\n                    dropout=dropout,\n                    mult=num_res_blocks,\n                    temperal_downsample=t_down_flag,\n                    down_flag=i != len(dim_mult) - 1,\n                )\n            )\n            scale /= 2.0\n        self.downsamples = nn.Sequential(*downsamples)\n\n        # middle blocks\n        self.middle = nn.Sequential(\n            ResidualBlock(out_dim, out_dim, dropout),\n            AttentionBlock(out_dim),\n            ResidualBlock(out_dim, out_dim, dropout),\n        )\n\n        # # output blocks\n        self.head = nn.Sequential(\n            RMS_norm(out_dim, images=False),\n            nn.SiLU(),\n            CausalConv3d(out_dim, z_dim, 3, padding=1),\n        )\n\n\n    def forward(self, x, feat_cache=None, feat_idx=[0]):\n\n        if feat_cache is not None:\n            idx = feat_idx[0]\n            cache_x = x[:, :, -CACHE_T:, :, :].clone()\n            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:\n                cache_x = torch.cat(\n                    [\n                        feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device),\n                        cache_x,\n                    ],\n                    dim=2,\n        
        )\n            x = self.conv1(x, feat_cache[idx])\n            feat_cache[idx] = cache_x\n            feat_idx[0] += 1\n        else:\n            x = self.conv1(x)\n\n        ## downsamples\n        for layer in self.downsamples:\n            if feat_cache is not None:\n                x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx)\n            else:\n                x = layer(x)\n\n        ## middle\n        for layer in self.middle:\n            if isinstance(layer, ResidualBlock) and feat_cache is not None:\n                x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx)\n            else:\n                x = layer(x)\n\n        ## head\n        for layer in self.head:\n            if isinstance(layer, CausalConv3d) and feat_cache is not None:\n                idx = feat_idx[0]\n                cache_x = x[:, :, -CACHE_T:, :, :].clone()\n                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:\n                    cache_x = torch.cat(\n                        [\n                            feat_cache[idx][:, :, -1, :, :]\n                            .unsqueeze(2)\n                            .to(cache_x.device),\n                            cache_x,\n                        ],\n                        dim=2,\n                    )\n                x = layer(x, feat_cache[idx])\n                feat_cache[idx] = cache_x\n                feat_idx[0] += 1\n            else:\n                x = layer(x)\n\n        return x, feat_cache, feat_idx\n\n\nclass Decoder3d(nn.Module):\n\n    def __init__(self,\n                 dim=128,\n                 z_dim=4,\n                 dim_mult=[1, 2, 4, 4],\n                 num_res_blocks=2,\n                 attn_scales=[],\n                 temperal_upsample=[False, True, True],\n                 dropout=0.0):\n        super().__init__()\n        self.dim = dim\n        self.z_dim = z_dim\n        self.dim_mult = dim_mult\n        self.num_res_blocks = num_res_blocks\n        
self.attn_scales = attn_scales\n        self.temperal_upsample = temperal_upsample\n\n        # dimensions\n        dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]\n        scale = 1.0 / 2**(len(dim_mult) - 2)\n\n        # init block\n        self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)\n\n        # middle blocks\n        self.middle = nn.Sequential(ResidualBlock(dims[0], dims[0], dropout),\n                                    AttentionBlock(dims[0]),\n                                    ResidualBlock(dims[0], dims[0], dropout))\n\n        # upsample blocks\n        upsamples = []\n        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):\n            # residual (+attention) blocks\n            if i == 1 or i == 2 or i == 3:\n                in_dim = in_dim // 2\n            for _ in range(num_res_blocks + 1):\n                upsamples.append(ResidualBlock(in_dim, out_dim, dropout))\n                if scale in attn_scales:\n                    upsamples.append(AttentionBlock(out_dim))\n                in_dim = out_dim\n\n            # upsample block\n            if i != len(dim_mult) - 1:\n                mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'\n                upsamples.append(Resample(out_dim, mode=mode))\n                scale *= 2.0\n        self.upsamples = nn.Sequential(*upsamples)\n\n        # output blocks\n        self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(),\n                                  CausalConv3d(out_dim, 3, 3, padding=1))\n\n    def forward(self, x, feat_cache=None, feat_idx=[0]):\n        ## conv1\n        if feat_cache is not None:\n            idx = feat_idx[0]\n            cache_x = x[:, :, -CACHE_T:, :, :].clone()\n            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:\n                # cache last frame of last two chunk\n                cache_x = torch.cat([\n                    feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(\n               
         cache_x.device), cache_x\n                ],\n                                    dim=2)\n            x = self.conv1(x, feat_cache[idx])\n            feat_cache[idx] = cache_x\n            feat_idx[0] += 1\n        else:\n            x = self.conv1(x)\n\n        ## middle\n        for layer in self.middle:\n            if check_is_instance(layer, ResidualBlock) and feat_cache is not None:\n                x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx)\n            else:\n                x = layer(x)\n\n        ## upsamples\n        for layer in self.upsamples:\n            if feat_cache is not None:\n                x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx)\n            else:\n                x = layer(x)\n\n        ## head\n        for layer in self.head:\n            if check_is_instance(layer, CausalConv3d) and feat_cache is not None:\n                idx = feat_idx[0]\n                cache_x = x[:, :, -CACHE_T:, :, :].clone()\n                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:\n                    # cache last frame of last two chunk\n                    cache_x = torch.cat([\n                        feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(\n                            cache_x.device), cache_x\n                    ],\n                                        dim=2)\n                x = layer(x, feat_cache[idx])\n                feat_cache[idx] = cache_x\n                feat_idx[0] += 1\n            else:\n                x = layer(x)\n        return x, feat_cache, feat_idx\n\n\n\nclass Decoder3d_38(nn.Module):\n\n    def __init__(self,\n                 dim=128,\n                 z_dim=4,\n                 dim_mult=[1, 2, 4, 4],\n                 num_res_blocks=2,\n                 attn_scales=[],\n                 temperal_upsample=[False, True, True],\n                 dropout=0.0):\n        super().__init__()\n        self.dim = dim\n        self.z_dim = z_dim\n        self.dim_mult = 
dim_mult\n        self.num_res_blocks = num_res_blocks\n        self.attn_scales = attn_scales\n        self.temperal_upsample = temperal_upsample\n\n        # dimensions\n        dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]\n        scale = 1.0 / 2 ** (len(dim_mult) - 2)\n        # init block\n        self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)\n\n        # middle blocks\n        self.middle = nn.Sequential(ResidualBlock(dims[0], dims[0], dropout),\n                                    AttentionBlock(dims[0]),\n                                    ResidualBlock(dims[0], dims[0], dropout))\n\n        # upsample blocks\n        upsamples = []\n        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):\n            t_up_flag = temperal_upsample[i] if i < len(temperal_upsample) else False\n            upsamples.append(\n                Up_ResidualBlock(in_dim=in_dim,\n                                 out_dim=out_dim,\n                                 dropout=dropout,\n                                 mult=num_res_blocks + 1,\n                                 temperal_upsample=t_up_flag,\n                                 up_flag=i != len(dim_mult) - 1))\n        self.upsamples = nn.Sequential(*upsamples)\n\n        # output blocks\n        self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(),\n                                  CausalConv3d(out_dim, 12, 3, padding=1))\n\n\n    def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):\n        if feat_cache is not None:\n            idx = feat_idx[0]\n            cache_x = x[:, :, -CACHE_T:, :, :].clone()\n            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:\n                cache_x = torch.cat(\n                    [\n                        feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device),\n                        cache_x,\n                    ],\n                    dim=2,\n                )\n            x = self.conv1(x, 
feat_cache[idx])\n            feat_cache[idx] = cache_x\n            feat_idx[0] += 1\n        else:\n            x = self.conv1(x)\n\n        for layer in self.middle:\n            if check_is_instance(layer, ResidualBlock) and feat_cache is not None:\n                x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx)\n            else:\n                x = layer(x)\n\n        ## upsamples\n        for layer in self.upsamples:\n            if feat_cache is not None:\n                x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx, first_chunk)\n            else:\n                x = layer(x)\n\n        ## head\n        for layer in self.head:\n            if check_is_instance(layer, CausalConv3d) and feat_cache is not None:\n                idx = feat_idx[0]\n                cache_x = x[:, :, -CACHE_T:, :, :].clone()\n                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:\n                    cache_x = torch.cat(\n                        [\n                            feat_cache[idx][:, :, -1, :, :]\n                            .unsqueeze(2)\n                            .to(cache_x.device),\n                            cache_x,\n                        ],\n                        dim=2,\n                    )\n                x = layer(x, feat_cache[idx])\n                feat_cache[idx] = cache_x\n                feat_idx[0] += 1\n            else:\n                x = layer(x)\n        return x, feat_cache, feat_idx\n\n\ndef count_conv3d(model):\n    count = 0\n    for m in model.modules():\n        if isinstance(m, CausalConv3d):\n            count += 1\n    return count\n\n\nclass VideoVAE_(nn.Module):\n\n    def __init__(self,\n                 dim=96,\n                 z_dim=16,\n                 dim_mult=[1, 2, 4, 4],\n                 num_res_blocks=2,\n                 attn_scales=[],\n                 temperal_downsample=[False, True, True],\n                 dropout=0.0):\n        super().__init__()\n        self.dim = 
dim\n        self.z_dim = z_dim\n        self.dim_mult = dim_mult\n        self.num_res_blocks = num_res_blocks\n        self.attn_scales = attn_scales\n        self.temperal_downsample = temperal_downsample\n        self.temperal_upsample = temperal_downsample[::-1]\n\n        # modules\n        self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,\n                                 attn_scales, self.temperal_downsample, dropout)\n        self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)\n        self.conv2 = CausalConv3d(z_dim, z_dim, 1)\n        self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,\n                                 attn_scales, self.temperal_upsample, dropout)\n\n    def forward(self, x):\n        mu, log_var = self.encode(x)\n        z = self.reparameterize(mu, log_var)\n        x_recon = self.decode(z)\n        return x_recon, mu, log_var\n\n    def encode(self, x, scale):\n        self.clear_cache()\n        ## cache\n        t = x.shape[2]\n        iter_ = 1 + (t - 1) // 4\n\n        for i in range(iter_):\n            self._enc_conv_idx = [0]\n            if i == 0:\n                out, self._enc_feat_map, self._enc_conv_idx = self.encoder(x[:, :, :1, :, :],\n                                   feat_cache=self._enc_feat_map,\n                                   feat_idx=self._enc_conv_idx)\n            else:\n                out_, self._enc_feat_map, self._enc_conv_idx = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],\n                                    feat_cache=self._enc_feat_map,\n                                    feat_idx=self._enc_conv_idx)\n                out = torch.cat([out, out_], 2)\n        mu, log_var = self.conv1(out).chunk(2, dim=1)\n        if isinstance(scale[0], torch.Tensor):\n            scale = [s.to(dtype=mu.dtype, device=mu.device) for s in scale]\n            mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(\n                1, self.z_dim, 1, 1, 1)\n        else:\n     
       scale = scale.to(dtype=mu.dtype, device=mu.device)\n            mu = (mu - scale[0]) * scale[1]\n        return mu\n\n    def decode(self, z, scale):\n        self.clear_cache()\n        # z: [b,c,t,h,w]\n        if isinstance(scale[0], torch.Tensor):\n            scale = [s.to(dtype=z.dtype, device=z.device) for s in scale]\n            z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(\n                1, self.z_dim, 1, 1, 1)\n        else:\n            scale = scale.to(dtype=z.dtype, device=z.device)\n            z = z / scale[1] + scale[0]\n        iter_ = z.shape[2]\n        x = self.conv2(z)\n        for i in range(iter_):\n            self._conv_idx = [0]\n            if i == 0:\n                out, self._feat_map, self._conv_idx = self.decoder(x[:, :, i:i + 1, :, :],\n                                   feat_cache=self._feat_map,\n                                   feat_idx=self._conv_idx)\n            else:\n                out_, self._feat_map, self._conv_idx = self.decoder(x[:, :, i:i + 1, :, :],\n                                    feat_cache=self._feat_map,\n                                    feat_idx=self._conv_idx)\n                out = torch.cat([out, out_], 2) # may add tensor offload\n        return out\n\n    def reparameterize(self, mu, log_var):\n        std = torch.exp(0.5 * log_var)\n        eps = torch.randn_like(std)\n        return eps * std + mu\n\n    def sample(self, imgs, deterministic=False):\n        mu, log_var = self.encode(imgs)\n        if deterministic:\n            return mu\n        std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))\n        return mu + std * torch.randn_like(std)\n\n    def clear_cache(self):\n        self._conv_num = count_conv3d(self.decoder)\n        self._conv_idx = [0]\n        self._feat_map = [None] * self._conv_num\n        # cache encode\n        self._enc_conv_num = count_conv3d(self.encoder)\n        self._enc_conv_idx = [0]\n        self._enc_feat_map = [None] * 
self._enc_conv_num\n\n\nclass WanVideoVAE(nn.Module):\n\n    def __init__(self, z_dim=16):\n        super().__init__()\n\n        mean = [\n            -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,\n            0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921\n        ]\n        std = [\n            2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,\n            3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160\n        ]\n        self.mean = torch.tensor(mean)\n        self.std = torch.tensor(std)\n        self.scale = [self.mean, 1.0 / self.std]\n\n        # init model\n        self.model = VideoVAE_(z_dim=z_dim).eval().requires_grad_(False)\n        self.upsampling_factor = 8\n        self.z_dim = z_dim\n\n\n    def build_1d_mask(self, length, left_bound, right_bound, border_width):\n        x = torch.ones((length,))\n        if not left_bound:\n            x[:border_width] = (torch.arange(border_width) + 1) / border_width\n        if not right_bound:\n            x[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,))\n        return x\n\n\n    def build_mask(self, data, is_bound, border_width):\n        _, _, _, H, W = data.shape\n        h = self.build_1d_mask(H, is_bound[0], is_bound[1], border_width[0])\n        w = self.build_1d_mask(W, is_bound[2], is_bound[3], border_width[1])\n\n        h = repeat(h, \"H -> H W\", H=H, W=W)\n        w = repeat(w, \"W -> H W\", H=H, W=W)\n\n        mask = torch.stack([h, w]).min(dim=0).values\n        mask = rearrange(mask, \"H W -> 1 1 1 H W\")\n        return mask\n\n\n    def tiled_decode(self, hidden_states, device, tile_size, tile_stride):\n        _, _, T, H, W = hidden_states.shape\n        size_h, size_w = tile_size\n        stride_h, stride_w = tile_stride\n\n        # Split tasks\n        tasks = []\n        for h in range(0, H, stride_h):\n            if (h-stride_h >= 0 and h-stride_h+size_h >= H): 
continue\n            for w in range(0, W, stride_w):\n                if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue\n                h_, w_ = h + size_h, w + size_w\n                tasks.append((h, h_, w, w_))\n\n        data_device = \"cpu\"\n        computation_device = device\n\n        out_T = T * 4 - 3\n        weight = torch.zeros((1, 1, out_T, H * self.upsampling_factor, W * self.upsampling_factor), dtype=hidden_states.dtype, device=data_device)\n        values = torch.zeros((1, 3, out_T, H * self.upsampling_factor, W * self.upsampling_factor), dtype=hidden_states.dtype, device=data_device)\n\n        for h, h_, w, w_ in tqdm(tasks, desc=\"VAE decoding\"):\n            hidden_states_batch = hidden_states[:, :, :, h:h_, w:w_].to(computation_device)\n            hidden_states_batch = self.model.decode(hidden_states_batch, self.scale).to(data_device)\n\n            mask = self.build_mask(\n                hidden_states_batch,\n                is_bound=(h==0, h_>=H, w==0, w_>=W),\n                border_width=((size_h - stride_h) * self.upsampling_factor, (size_w - stride_w) * self.upsampling_factor)\n            ).to(dtype=hidden_states.dtype, device=data_device)\n\n            target_h = h * self.upsampling_factor\n            target_w = w * self.upsampling_factor\n            values[\n                :,\n                :,\n                :,\n                target_h:target_h + hidden_states_batch.shape[3],\n                target_w:target_w + hidden_states_batch.shape[4],\n            ] += hidden_states_batch * mask\n            weight[\n                :,\n                :,\n                :,\n                target_h: target_h + hidden_states_batch.shape[3],\n                target_w: target_w + hidden_states_batch.shape[4],\n            ] += mask\n        values = values / weight\n        values = values.clamp_(-1, 1)\n        return values\n\n\n    def tiled_encode(self, video, device, tile_size, tile_stride):\n        _, _, T, H, W = 
video.shape\n        size_h, size_w = tile_size\n        stride_h, stride_w = tile_stride\n\n        # Split tasks\n        tasks = []\n        for h in range(0, H, stride_h):\n            if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue\n            for w in range(0, W, stride_w):\n                if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue\n                h_, w_ = h + size_h, w + size_w\n                tasks.append((h, h_, w, w_))\n\n        data_device = \"cpu\"\n        computation_device = device\n\n        out_T = (T + 3) // 4\n        weight = torch.zeros((1, 1, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device)\n        values = torch.zeros((1, self.z_dim, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device)\n\n        for h, h_, w, w_ in tqdm(tasks, desc=\"VAE encoding\"):\n            hidden_states_batch = video[:, :, :, h:h_, w:w_].to(computation_device)\n            hidden_states_batch = self.model.encode(hidden_states_batch, self.scale).to(data_device)\n\n            mask = self.build_mask(\n                hidden_states_batch,\n                is_bound=(h==0, h_>=H, w==0, w_>=W),\n                border_width=((size_h - stride_h) // self.upsampling_factor, (size_w - stride_w) // self.upsampling_factor)\n            ).to(dtype=video.dtype, device=data_device)\n\n            target_h = h // self.upsampling_factor\n            target_w = w // self.upsampling_factor\n            values[\n                :,\n                :,\n                :,\n                target_h:target_h + hidden_states_batch.shape[3],\n                target_w:target_w + hidden_states_batch.shape[4],\n            ] += hidden_states_batch * mask\n            weight[\n                :,\n                :,\n                :,\n                target_h: target_h + hidden_states_batch.shape[3],\n                target_w: target_w + 
hidden_states_batch.shape[4],\n            ] += mask\n        values = values / weight\n        return values\n\n\n    def single_encode(self, video, device):\n        video = video.to(device)\n        x = self.model.encode(video, self.scale)\n        return x\n\n\n    def single_decode(self, hidden_state, device):\n        hidden_state = hidden_state.to(device)\n        video = self.model.decode(hidden_state, self.scale)\n        return video.clamp_(-1, 1)\n\n\n    def encode(self, videos, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):\n        videos = [video.to(\"cpu\") for video in videos]\n        hidden_states = []\n        for video in videos:\n            video = video.unsqueeze(0)\n            if tiled:\n                tile_size = (tile_size[0] * self.upsampling_factor, tile_size[1] * self.upsampling_factor)\n                tile_stride = (tile_stride[0] * self.upsampling_factor, tile_stride[1] * self.upsampling_factor)\n                hidden_state = self.tiled_encode(video, device, tile_size, tile_stride)\n            else:\n                hidden_state = self.single_encode(video, device)\n            hidden_state = hidden_state.squeeze(0)\n            hidden_states.append(hidden_state)\n        hidden_states = torch.stack(hidden_states)\n        return hidden_states\n\n\n    def decode(self, hidden_states, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):\n        hidden_states = [hidden_state.to(\"cpu\") for hidden_state in hidden_states]\n        videos = []\n        for hidden_state in hidden_states:\n            hidden_state = hidden_state.unsqueeze(0)\n            if tiled:\n                video = self.tiled_decode(hidden_state, device, tile_size, tile_stride)\n            else:\n                video = self.single_decode(hidden_state, device)\n            video = video.squeeze(0)\n            videos.append(video)\n        videos = torch.stack(videos)\n        return videos\n\n\n    def encode_framewise(self, 
videos, device):\n        hidden_states = []\n        for i in range(videos.shape[2]):\n            hidden_states.append(self.single_encode(videos[:, :, i:i+1], device))\n        hidden_states = torch.concat(hidden_states, dim=2)\n        return hidden_states\n    \n\n    def decode_framewise(self, hidden_states, device):\n        video = []\n        for i in range(hidden_states.shape[2]):\n            video.append(self.single_decode(hidden_states[:, :, i:i+1], device))\n        video = torch.concat(video, dim=2)\n        return video\n\n\n    @staticmethod\n    def state_dict_converter():\n        return WanVideoVAEStateDictConverter()\n\n\nclass WanVideoVAEStateDictConverter:\n\n    def __init__(self):\n        pass\n\n    def from_civitai(self, state_dict):\n        state_dict_ = {}\n        if 'model_state' in state_dict:\n            state_dict = state_dict['model_state']\n        for name in state_dict:\n            state_dict_['model.' + name] = state_dict[name]\n        return state_dict_\n\n\nclass VideoVAE38_(VideoVAE_):\n\n    def __init__(self,\n                 dim=160,\n                 z_dim=48,\n                 dec_dim=256,\n                 dim_mult=[1, 2, 4, 4],\n                 num_res_blocks=2,\n                 attn_scales=[],\n                 temperal_downsample=[False, True, True],\n                 dropout=0.0):\n        super(VideoVAE_, self).__init__()\n        self.dim = dim\n        self.z_dim = z_dim\n        self.dim_mult = dim_mult\n        self.num_res_blocks = num_res_blocks\n        self.attn_scales = attn_scales\n        self.temperal_downsample = temperal_downsample\n        self.temperal_upsample = temperal_downsample[::-1]\n\n        # modules\n        self.encoder = Encoder3d_38(dim, z_dim * 2, dim_mult, num_res_blocks,\n                                    attn_scales, self.temperal_downsample, dropout)\n        self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)\n        self.conv2 = CausalConv3d(z_dim, z_dim, 1)\n        
self.decoder = Decoder3d_38(dec_dim, z_dim, dim_mult, num_res_blocks,\n                                    attn_scales, self.temperal_upsample, dropout)\n\n\n    def encode(self, x, scale):\n        self.clear_cache()\n        x = patchify(x, patch_size=2)\n        t = x.shape[2]\n        iter_ = 1 + (t - 1) // 4\n        for i in range(iter_):\n            self._enc_conv_idx = [0]\n            if i == 0:\n                out, self._enc_feat_map, self._enc_conv_idx = self.encoder(x[:, :, :1, :, :],\n                                   feat_cache=self._enc_feat_map,\n                                   feat_idx=self._enc_conv_idx)\n            else:\n                out_, self._enc_feat_map, self._enc_conv_idx = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],\n                                    feat_cache=self._enc_feat_map,\n                                    feat_idx=self._enc_conv_idx)\n                out = torch.cat([out, out_], 2)\n        mu, log_var = self.conv1(out).chunk(2, dim=1)\n        if isinstance(scale[0], torch.Tensor):\n            scale = [s.to(dtype=mu.dtype, device=mu.device) for s in scale]\n            mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(\n                1, self.z_dim, 1, 1, 1)\n        else:\n            scale = scale.to(dtype=mu.dtype, device=mu.device)\n            mu = (mu - scale[0]) * scale[1]\n        self.clear_cache()\n        return mu\n\n\n    def decode(self, z, scale):\n        self.clear_cache()\n        if isinstance(scale[0], torch.Tensor):\n            scale = [s.to(dtype=z.dtype, device=z.device) for s in scale]\n            z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(\n                1, self.z_dim, 1, 1, 1)\n        else:\n            scale = scale.to(dtype=z.dtype, device=z.device)\n            z = z / scale[1] + scale[0]\n        iter_ = z.shape[2]\n        x = self.conv2(z)\n        for i in range(iter_):\n            self._conv_idx = [0]\n            if i == 0:\n  
              out, self._feat_map, self._conv_idx = self.decoder(x[:, :, i:i + 1, :, :],\n                                   feat_cache=self._feat_map,\n                                   feat_idx=self._conv_idx,\n                                   first_chunk=True)\n            else:\n                out_, self._feat_map, self._conv_idx = self.decoder(x[:, :, i:i + 1, :, :],\n                                    feat_cache=self._feat_map,\n                                    feat_idx=self._conv_idx)\n                out = torch.cat([out, out_], 2)\n        out = unpatchify(out, patch_size=2)\n        self.clear_cache()\n        return out\n\n\nclass WanVideoVAE38(WanVideoVAE):\n\n    def __init__(self, z_dim=48, dim=160):\n        super(WanVideoVAE, self).__init__()\n\n        mean = [\n            -0.2289, -0.0052, -0.1323, -0.2339, -0.2799,  0.0174,  0.1838,  0.1557,\n            -0.1382,  0.0542,  0.2813,  0.0891,  0.1570, -0.0098,  0.0375, -0.1825,\n            -0.2246, -0.1207, -0.0698,  0.5109,  0.2665, -0.2108, -0.2158,  0.2502,\n            -0.2055, -0.0322,  0.1109,  0.1567, -0.0729,  0.0899, -0.2799, -0.1230,\n            -0.0313, -0.1649,  0.0117,  0.0723, -0.2839, -0.2083, -0.0520,  0.3748,\n            0.0152,  0.1957,  0.1433, -0.2944,  0.3573, -0.0548, -0.1681, -0.0667\n        ]\n        std = [\n            0.4765, 1.0364, 0.4514, 1.1677, 0.5313, 0.4990, 0.4818, 0.5013,\n            0.8158, 1.0344, 0.5894, 1.0901, 0.6885, 0.6165, 0.8454, 0.4978,\n            0.5759, 0.3523, 0.7135, 0.6804, 0.5833, 1.4146, 0.8986, 0.5659,\n            0.7069, 0.5338, 0.4889, 0.4917, 0.4069, 0.4999, 0.6866, 0.4093,\n            0.5709, 0.6065, 0.6415, 0.4944, 0.5726, 1.2042, 0.5458, 1.6887,\n            0.3971, 1.0600, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744\n        ]\n        self.mean = torch.tensor(mean)\n        self.std = torch.tensor(std)\n        self.scale = [self.mean, 1.0 / self.std]\n\n        # init model\n        self.model = 
VideoVAE38_(z_dim=z_dim, dim=dim).eval().requires_grad_(False)\n        self.upsampling_factor = 16\n        self.z_dim = z_dim\n"
  },
  {
    "path": "diffsynth/models/wantodance.py",
    "content": "from inspect import isfunction\nfrom math import log, pi\n\nimport torch\nfrom einops import rearrange, repeat\nfrom torch import einsum, nn\n\nfrom typing import Any, Callable, List, Optional, Union\nfrom torch import Tensor\nimport torch.nn.functional as F\n\n# helper functions\n\n\ndef exists(val):\n    return val is not None\n\n\ndef broadcat(tensors, dim=-1):\n    num_tensors = len(tensors)\n    shape_lens = set(list(map(lambda t: len(t.shape), tensors)))\n    assert len(shape_lens) == 1, \"tensors must all have the same number of dimensions\"\n    shape_len = list(shape_lens)[0]\n\n    dim = (dim + shape_len) if dim < 0 else dim\n    dims = list(zip(*map(lambda t: list(t.shape), tensors)))\n\n    expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]\n    assert all(\n        [*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]\n    ), \"invalid dimensions for broadcastable concatentation\"\n    max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))\n    expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))\n    expanded_dims.insert(dim, (dim, dims[dim]))\n    expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))\n    tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))\n    return torch.cat(tensors, dim=dim)\n\n\n# rotary embedding helper functions\n\n\ndef rotate_half(x):\n    x = rearrange(x, \"... (d r) -> ... d r\", r=2)\n    x1, x2 = x.unbind(dim=-1)\n    x = torch.stack((-x2, x1), dim=-1)\n    return rearrange(x, \"... d r -> ... 
(d r)\")\n\n\ndef apply_rotary_emb(freqs, t, start_index=0):\n    freqs = freqs.to(t)\n    rot_dim = freqs.shape[-1]\n    end_index = start_index + rot_dim\n    assert (\n        rot_dim <= t.shape[-1]\n    ), f\"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}\"\n    t_left, t, t_right = (\n        t[..., :start_index],\n        t[..., start_index:end_index],\n        t[..., end_index:],\n    )\n    t = (t * freqs.cos()) + (rotate_half(t) * freqs.sin())\n    return torch.cat((t_left, t, t_right), dim=-1)\n\n\n# learned rotation helpers\n\n\ndef apply_learned_rotations(rotations, t, start_index=0, freq_ranges=None):\n    if exists(freq_ranges):\n        rotations = einsum(\"..., f -> ... f\", rotations, freq_ranges)\n        rotations = rearrange(rotations, \"... r f -> ... (r f)\")\n\n    rotations = repeat(rotations, \"... n -> ... (n r)\", r=2)\n    return apply_rotary_emb(rotations, t, start_index=start_index)\n\n\n# classes\n\n\nclass WanToDanceRotaryEmbedding(nn.Module):\n    def __init__(\n        self,\n        dim,\n        custom_freqs=None,\n        freqs_for=\"lang\",\n        theta=10000,\n        max_freq=10,\n        num_freqs=1,\n        learned_freq=False,\n    ):\n        super().__init__()\n        if exists(custom_freqs):\n            freqs = custom_freqs\n        elif freqs_for == \"lang\":\n            freqs = 1.0 / (\n                theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)\n            )\n        elif freqs_for == \"pixel\":\n            freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi\n        elif freqs_for == \"constant\":\n            freqs = torch.ones(num_freqs).float()\n        else:\n            raise ValueError(f\"unknown modality {freqs_for}\")\n\n        self.cache = dict()\n\n        if learned_freq:\n            self.freqs = nn.Parameter(freqs)\n        else:\n            self.register_buffer(\"freqs\", freqs, persistent=False)\n\n    def 
rotate_queries_or_keys(self, t, seq_dim=-2):\n        device = t.device\n        seq_len = t.shape[seq_dim]\n        freqs = self.forward(\n            lambda: torch.arange(seq_len, device=device), cache_key=seq_len\n        )\n        return apply_rotary_emb(freqs, t)\n\n    def forward(self, t, cache_key=None):\n        if exists(cache_key) and cache_key in self.cache:\n            return self.cache[cache_key]\n\n        if isfunction(t):\n            t = t()\n\n        # freqs = self.freqs\n        freqs = self.freqs.to(t.device)\n\n        freqs = torch.einsum(\"..., f -> ... f\", t.type(freqs.dtype), freqs)\n        freqs = repeat(freqs, \"... n -> ... (n r)\", r=2)\n\n        if exists(cache_key):\n            self.cache[cache_key] = freqs\n\n        return freqs\n\n\nclass WanToDanceMusicEncoderLayer(nn.Module):\n    def __init__(\n        self,\n        d_model: int,\n        nhead: int,\n        dim_feedforward: int = 2048,\n        dropout: float = 0.1,\n        activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,\n        layer_norm_eps: float = 1e-5,\n        batch_first: bool = False,\n        norm_first: bool = True,\n        device=None,\n        dtype=None,\n        rotary=None,\n    ) -> None:\n        super().__init__()\n        self.self_attn = nn.MultiheadAttention(\n            d_model, nhead, dropout=dropout, batch_first=batch_first, device=device, dtype=dtype\n        )\n        # Implementation of Feedforward model\n        self.linear1 = nn.Linear(d_model, dim_feedforward)\n        self.dropout = nn.Dropout(dropout)\n        self.linear2 = nn.Linear(dim_feedforward, d_model)\n\n        self.norm_first = norm_first\n        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)\n        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)\n        self.dropout1 = nn.Dropout(dropout)\n        self.dropout2 = nn.Dropout(dropout)\n        self.activation = activation\n\n        self.rotary = rotary\n        self.use_rotary = rotary is 
not None\n\n    # self-attention block\n    def _sa_block(\n        self, x: Tensor, attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]\n    ) -> Tensor:\n        qk = self.rotary.rotate_queries_or_keys(x) if self.use_rotary else x\n        x = self.self_attn(\n            qk,\n            qk,\n            x,\n            attn_mask=attn_mask,\n            key_padding_mask=key_padding_mask,\n            need_weights=False,\n        )[0]\n        return self.dropout1(x)\n\n    # feed forward block\n    def _ff_block(self, x: Tensor) -> Tensor:\n        x = self.linear2(self.dropout(self.activation(self.linear1(x))))\n        return self.dropout2(x)\n\n    def forward(\n        self,\n        src: Tensor,\n        src_mask: Optional[Tensor] = None,\n        src_key_padding_mask: Optional[Tensor] = None,\n    ) -> Tensor:\n        x = src\n        if self.norm_first:\n            self.norm1.to(device=x.device)\n            self.norm2.to(device=x.device)\n            x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)\n            x = x + self._ff_block(self.norm2(x))\n        else:\n            x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask))\n            x = self.norm2(x + self._ff_block(x))\n        return x"
  },
  {
    "path": "diffsynth/models/wav2vec.py",
    "content": "import math\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\n\n\ndef get_sample_indices(original_fps, total_frames, target_fps, num_sample, fixed_start=None):\n    required_duration = num_sample / target_fps\n    required_origin_frames = int(np.ceil(required_duration * original_fps))\n    if required_duration > total_frames / original_fps:\n        raise ValueError(\"required_duration must be less than video length\")\n\n    if not fixed_start is None and fixed_start >= 0:\n        start_frame = fixed_start\n    else:\n        max_start = total_frames - required_origin_frames\n        if max_start < 0:\n            raise ValueError(\"video length is too short\")\n        start_frame = np.random.randint(0, max_start + 1)\n    start_time = start_frame / original_fps\n\n    end_time = start_time + required_duration\n    time_points = np.linspace(start_time, end_time, num_sample, endpoint=False)\n\n    frame_indices = np.round(np.array(time_points) * original_fps).astype(int)\n    frame_indices = np.clip(frame_indices, 0, total_frames - 1)\n    return frame_indices\n\n\ndef linear_interpolation(features, input_fps, output_fps, output_len=None):\n    \"\"\"\n    features: shape=[1, T, 512]\n    input_fps: fps for audio, f_a\n    output_fps: fps for video, f_m\n    output_len: video length\n    \"\"\"\n    features = features.transpose(1, 2)\n    seq_len = features.shape[2] / float(input_fps)\n    if output_len is None:\n        output_len = int(seq_len * output_fps)\n    output_features = F.interpolate(features, size=output_len, align_corners=True, mode='linear')  # [1, 512, output_len]\n    return output_features.transpose(1, 2)\n\n\nclass WanS2VAudioEncoder(torch.nn.Module):\n\n    def __init__(self):\n        super().__init__()\n        from transformers import Wav2Vec2ForCTC, Wav2Vec2Config\n        config = {\n            \"_name_or_path\": \"facebook/wav2vec2-large-xlsr-53\",\n            \"activation_dropout\": 0.05,\n           
 \"apply_spec_augment\": True,\n            \"architectures\": [\"Wav2Vec2ForCTC\"],\n            \"attention_dropout\": 0.1,\n            \"bos_token_id\": 1,\n            \"conv_bias\": True,\n            \"conv_dim\": [512, 512, 512, 512, 512, 512, 512],\n            \"conv_kernel\": [10, 3, 3, 3, 3, 2, 2],\n            \"conv_stride\": [5, 2, 2, 2, 2, 2, 2],\n            \"ctc_loss_reduction\": \"mean\",\n            \"ctc_zero_infinity\": True,\n            \"do_stable_layer_norm\": True,\n            \"eos_token_id\": 2,\n            \"feat_extract_activation\": \"gelu\",\n            \"feat_extract_dropout\": 0.0,\n            \"feat_extract_norm\": \"layer\",\n            \"feat_proj_dropout\": 0.05,\n            \"final_dropout\": 0.0,\n            \"hidden_act\": \"gelu\",\n            \"hidden_dropout\": 0.05,\n            \"hidden_size\": 1024,\n            \"initializer_range\": 0.02,\n            \"intermediate_size\": 4096,\n            \"layer_norm_eps\": 1e-05,\n            \"layerdrop\": 0.05,\n            \"mask_channel_length\": 10,\n            \"mask_channel_min_space\": 1,\n            \"mask_channel_other\": 0.0,\n            \"mask_channel_prob\": 0.0,\n            \"mask_channel_selection\": \"static\",\n            \"mask_feature_length\": 10,\n            \"mask_feature_prob\": 0.0,\n            \"mask_time_length\": 10,\n            \"mask_time_min_space\": 1,\n            \"mask_time_other\": 0.0,\n            \"mask_time_prob\": 0.05,\n            \"mask_time_selection\": \"static\",\n            \"model_type\": \"wav2vec2\",\n            \"num_attention_heads\": 16,\n            \"num_conv_pos_embedding_groups\": 16,\n            \"num_conv_pos_embeddings\": 128,\n            \"num_feat_extract_layers\": 7,\n            \"num_hidden_layers\": 24,\n            \"pad_token_id\": 0,\n            \"transformers_version\": \"4.7.0.dev0\",\n            \"vocab_size\": 33\n        }\n        self.model = 
Wav2Vec2ForCTC(Wav2Vec2Config(**config))\n        self.video_rate = 30\n\n    def extract_audio_feat(self, input_audio, sample_rate, processor, return_all_layers=False, dtype=torch.float32, device='cpu'):\n        input_values = processor(input_audio, sampling_rate=sample_rate, return_tensors=\"pt\").input_values.to(dtype=dtype, device=device)\n\n        # retrieve logits & take argmax\n        res = self.model(input_values, output_hidden_states=True)\n        if return_all_layers:\n            feat = torch.cat(res.hidden_states)\n        else:\n            feat = res.hidden_states[-1]\n        feat = linear_interpolation(feat, input_fps=50, output_fps=self.video_rate)\n        return feat\n\n    def get_audio_embed_bucket(self, audio_embed, stride=2, batch_frames=12, m=2):\n        num_layers, audio_frame_num, audio_dim = audio_embed.shape\n\n        if num_layers > 1:\n            return_all_layers = True\n        else:\n            return_all_layers = False\n\n        min_batch_num = int(audio_frame_num / (batch_frames * stride)) + 1\n\n        bucket_num = min_batch_num * batch_frames\n        batch_idx = [stride * i for i in range(bucket_num)]\n        batch_audio_eb = []\n        for bi in batch_idx:\n            if bi < audio_frame_num:\n                audio_sample_stride = 2\n                chosen_idx = list(range(bi - m * audio_sample_stride, bi + (m + 1) * audio_sample_stride, audio_sample_stride))\n                chosen_idx = [0 if c < 0 else c for c in chosen_idx]\n                chosen_idx = [audio_frame_num - 1 if c >= audio_frame_num else c for c in chosen_idx]\n\n                if return_all_layers:\n                    frame_audio_embed = audio_embed[:, chosen_idx].flatten(start_dim=-2, end_dim=-1)\n                else:\n                    frame_audio_embed = audio_embed[0][chosen_idx].flatten()\n            else:\n                frame_audio_embed = \\\n                torch.zeros([audio_dim * (2 * m + 1)], device=audio_embed.device) if not 
return_all_layers \\\n                    else torch.zeros([num_layers, audio_dim * (2 * m + 1)], device=audio_embed.device)\n            batch_audio_eb.append(frame_audio_embed)\n        batch_audio_eb = torch.cat([c.unsqueeze(0) for c in batch_audio_eb], dim=0)\n\n        return batch_audio_eb, min_batch_num\n\n    def get_audio_embed_bucket_fps(self, audio_embed, fps=16, batch_frames=81, m=0):\n        num_layers, audio_frame_num, audio_dim = audio_embed.shape\n\n        if num_layers > 1:\n            return_all_layers = True\n        else:\n            return_all_layers = False\n\n        scale = self.video_rate / fps\n\n        min_batch_num = int(audio_frame_num / (batch_frames * scale)) + 1\n\n        bucket_num = min_batch_num * batch_frames\n        padd_audio_num = math.ceil(min_batch_num * batch_frames / fps * self.video_rate) - audio_frame_num\n        batch_idx = get_sample_indices(\n            original_fps=self.video_rate, total_frames=audio_frame_num + padd_audio_num, target_fps=fps, num_sample=bucket_num, fixed_start=0\n        )\n        batch_audio_eb = []\n        audio_sample_stride = int(self.video_rate / fps)\n        for bi in batch_idx:\n            if bi < audio_frame_num:\n\n                chosen_idx = list(range(bi - m * audio_sample_stride, bi + (m + 1) * audio_sample_stride, audio_sample_stride))\n                chosen_idx = [0 if c < 0 else c for c in chosen_idx]\n                chosen_idx = [audio_frame_num - 1 if c >= audio_frame_num else c for c in chosen_idx]\n\n                if return_all_layers:\n                    frame_audio_embed = audio_embed[:, chosen_idx].flatten(start_dim=-2, end_dim=-1)\n                else:\n                    frame_audio_embed = audio_embed[0][chosen_idx].flatten()\n            else:\n                frame_audio_embed = \\\n                torch.zeros([audio_dim * (2 * m + 1)], device=audio_embed.device) if not return_all_layers \\\n                    else torch.zeros([num_layers, audio_dim * 
(2 * m + 1)], device=audio_embed.device)\n            batch_audio_eb.append(frame_audio_embed)\n        batch_audio_eb = torch.cat([c.unsqueeze(0) for c in batch_audio_eb], dim=0)\n\n        return batch_audio_eb, min_batch_num\n\n    def get_audio_feats_per_inference(self, input_audio, sample_rate, processor, fps=16, batch_frames=80, m=0, dtype=torch.float32, device='cpu'):\n        audio_feat = self.extract_audio_feat(input_audio, sample_rate, processor, return_all_layers=True, dtype=dtype, device=device)\n        audio_embed_bucket, min_batch_num = self.get_audio_embed_bucket_fps(audio_feat, fps=fps, batch_frames=batch_frames, m=m)\n        audio_embed_bucket = audio_embed_bucket.unsqueeze(0).permute(0, 2, 3, 1).to(device, dtype)\n        audio_embeds = [audio_embed_bucket[..., i * batch_frames:(i + 1) * batch_frames] for i in range(min_batch_num)]\n        return audio_embeds\n"
  },
  {
    "path": "diffsynth/models/z_image_controlnet.py",
    "content": "from .z_image_dit import ZImageTransformerBlock\nfrom ..core.gradient import gradient_checkpoint_forward\nfrom torch.nn.utils.rnn import pad_sequence\nimport torch\nfrom torch import nn\n\n\nclass ZImageControlTransformerBlock(ZImageTransformerBlock):\n    def __init__(\n        self, \n        layer_id: int = 1000,\n        dim: int = 3840,\n        n_heads: int = 30,\n        n_kv_heads: int = 30,\n        norm_eps: float = 1e-5,\n        qk_norm: bool = True,\n        modulation = True,\n        block_id = 0\n    ):\n        super().__init__(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation)\n        self.block_id = block_id\n        if block_id == 0:\n            self.before_proj = nn.Linear(self.dim, self.dim)\n        self.after_proj = nn.Linear(self.dim, self.dim)\n\n    def forward(self, c, x, **kwargs):\n        if self.block_id == 0:\n            c = self.before_proj(c) + x\n            all_c = []\n        else:\n            all_c = list(torch.unbind(c))\n            c = all_c.pop(-1)\n\n        c = super().forward(c, **kwargs)\n        c_skip = self.after_proj(c)\n        all_c += [c_skip, c]\n        c = torch.stack(all_c)\n        return c\n\n\nclass ZImageControlNet(torch.nn.Module):\n    def __init__(\n        self,\n        control_layers_places=(0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28),\n        control_in_dim=33,\n        dim=3840,\n        n_refiner_layers=2,\n    ):\n        super().__init__()\n        self.control_layers = nn.ModuleList([ZImageControlTransformerBlock(layer_id=i, block_id=i) for i in control_layers_places])\n        self.control_all_x_embedder = nn.ModuleDict({\"2-1\": nn.Linear(1 * 2 * 2 * control_in_dim, dim, bias=True)})\n        self.control_noise_refiner = nn.ModuleList([ZImageControlTransformerBlock(block_id=layer_id) for layer_id in range(n_refiner_layers)])\n        self.control_layers_mapping = {0: 0, 2: 1, 4: 2, 6: 3, 8: 4, 10: 5, 12: 6, 14: 7, 16: 8, 18: 9, 20: 10, 22: 11, 
24: 12, 26: 13, 28: 14}\n\n    def forward_layers(\n        self,\n        x,\n        cap_feats,\n        control_context,\n        control_context_item_seqlens,\n        kwargs,\n        use_gradient_checkpointing=False,\n        use_gradient_checkpointing_offload=False,\n    ):\n        bsz = len(control_context)\n        # unified\n        cap_item_seqlens = [len(_) for _ in cap_feats]\n        control_context_unified = []\n        for i in range(bsz):\n            control_context_len = control_context_item_seqlens[i]\n            cap_len = cap_item_seqlens[i]\n            control_context_unified.append(torch.cat([control_context[i][:control_context_len], cap_feats[i][:cap_len]]))\n        c = pad_sequence(control_context_unified, batch_first=True, padding_value=0.0)\n\n        # arguments\n        new_kwargs = dict(x=x)\n        new_kwargs.update(kwargs)\n        \n        for layer in self.control_layers:\n            c = gradient_checkpoint_forward(\n                layer,\n                use_gradient_checkpointing=use_gradient_checkpointing,\n                use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,\n                c=c, **new_kwargs\n            )\n \n        hints = torch.unbind(c)[:-1]\n        return hints\n    \n    def forward_refiner(\n        self,\n        dit,\n        x,\n        cap_feats,\n        control_context,\n        kwargs,\n        t=None,\n        patch_size=2,\n        f_patch_size=1,\n        use_gradient_checkpointing=False,\n        use_gradient_checkpointing_offload=False,\n    ):\n        # embeddings\n        bsz = len(control_context)\n        device = control_context[0].device\n        (\n            control_context,\n            control_context_size,\n            control_context_pos_ids,\n            control_context_inner_pad_mask,\n        ) = dit.patchify_controlnet(control_context, patch_size, f_patch_size, cap_feats[0].size(0))\n\n        # control_context embed & refine\n        
control_context_item_seqlens = [len(_) for _ in control_context]\n        assert all(_ % 2 == 0 for _ in control_context_item_seqlens)\n        control_context_max_item_seqlen = max(control_context_item_seqlens)\n\n        control_context = torch.cat(control_context, dim=0)\n        control_context = self.control_all_x_embedder[f\"{patch_size}-{f_patch_size}\"](control_context)\n\n        # Match t_embedder output dtype to control_context for layerwise casting compatibility\n        adaln_input = t.type_as(control_context)\n        control_context[torch.cat(control_context_inner_pad_mask)] = dit.x_pad_token.to(dtype=control_context.dtype, device=control_context.device)\n        control_context = list(control_context.split(control_context_item_seqlens, dim=0))\n        control_context_freqs_cis = list(dit.rope_embedder(torch.cat(control_context_pos_ids, dim=0)).split(control_context_item_seqlens, dim=0))\n\n        control_context = pad_sequence(control_context, batch_first=True, padding_value=0.0)\n        control_context_freqs_cis = pad_sequence(control_context_freqs_cis, batch_first=True, padding_value=0.0)\n        control_context_attn_mask = torch.zeros((bsz, control_context_max_item_seqlen), dtype=torch.bool, device=device)\n        for i, seq_len in enumerate(control_context_item_seqlens):\n            control_context_attn_mask[i, :seq_len] = 1\n        c = control_context\n\n        # arguments\n        new_kwargs = dict(\n            x=x, \n            attn_mask=control_context_attn_mask,\n            freqs_cis=control_context_freqs_cis, \n            adaln_input=adaln_input,\n        )\n        new_kwargs.update(kwargs)\n        \n        for layer in self.control_noise_refiner:\n            c = gradient_checkpoint_forward(\n                layer,\n                use_gradient_checkpointing=use_gradient_checkpointing,\n                use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,\n                c=c, **new_kwargs\n            )\n 
\n        hints = torch.unbind(c)[:-1]\n        control_context = torch.unbind(c)[-1]\n\n        return hints, control_context, control_context_item_seqlens"
  },
  {
    "path": "diffsynth/models/z_image_dit.py",
    "content": "import math\nfrom typing import List, Optional, Tuple\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.nn.utils.rnn import pad_sequence\n\nfrom .general_modules import RMSNorm\nfrom ..core.attention import attention_forward\nfrom ..core.device.npu_compatible_device import IS_NPU_AVAILABLE, get_device_type\nfrom ..core.gradient import gradient_checkpoint_forward\n\n\nADALN_EMBED_DIM = 256\nSEQ_MULTI_OF = 32\nX_PAD_DIM = 64\n\n\nclass TimestepEmbedder(nn.Module):\n    def __init__(self, out_size, mid_size=None, frequency_embedding_size=256):\n        super().__init__()\n        if mid_size is None:\n            mid_size = out_size\n        self.mlp = nn.Sequential(\n            nn.Linear(\n                frequency_embedding_size,\n                mid_size,\n                bias=True,\n            ),\n            nn.SiLU(),\n            nn.Linear(\n                mid_size,\n                out_size,\n                bias=True,\n            ),\n        )\n\n        self.frequency_embedding_size = frequency_embedding_size\n\n    @staticmethod\n    def timestep_embedding(t, dim, max_period=10000):\n        with torch.amp.autocast(get_device_type(), enabled=False):\n            half = dim // 2\n            freqs = torch.exp(\n                -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half\n            )\n            args = t[:, None].float() * freqs[None]\n            embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)\n            if dim % 2:\n                embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)\n            return embedding\n\n    def forward(self, t):\n        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)\n        t_emb = self.mlp(t_freq.to(torch.bfloat16))\n        return t_emb\n\n\nclass FeedForward(nn.Module):\n    def __init__(self, dim: int, hidden_dim: int):\n        
super().__init__()\n        self.w1 = nn.Linear(dim, hidden_dim, bias=False)\n        self.w2 = nn.Linear(hidden_dim, dim, bias=False)\n        self.w3 = nn.Linear(dim, hidden_dim, bias=False)\n\n    def _forward_silu_gating(self, x1, x3):\n        return F.silu(x1) * x3\n\n    def forward(self, x):\n        return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))\n\n\nclass Attention(torch.nn.Module):\n\n    def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):\n        super().__init__()\n        dim_inner = head_dim * num_heads\n        kv_dim = kv_dim if kv_dim is not None else q_dim\n        self.num_heads = num_heads\n        self.head_dim = head_dim\n\n        self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q)\n        self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)\n        self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)\n        self.to_out = torch.nn.ModuleList([torch.nn.Linear(dim_inner, q_dim, bias=bias_out)])\n\n        self.norm_q = RMSNorm(head_dim, eps=1e-5)\n        self.norm_k = RMSNorm(head_dim, eps=1e-5)\n    \n    # Apply RoPE\n    def apply_rotary_emb(self, x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:\n        with torch.amp.autocast(get_device_type(), enabled=False):\n            x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))\n            freqs_cis = freqs_cis.unsqueeze(2)\n            x_out = torch.view_as_real(x * freqs_cis).flatten(3)\n            return x_out.type_as(x_in)  # todo\n\n    def forward(self, hidden_states, freqs_cis, attention_mask):\n        query = self.to_q(hidden_states)\n        key = self.to_k(hidden_states)\n        value = self.to_v(hidden_states)\n\n        query = query.unflatten(-1, (self.num_heads, -1))\n        key = key.unflatten(-1, (self.num_heads, -1))\n        value = value.unflatten(-1, (self.num_heads, -1))\n\n        # Apply Norms\n        if self.norm_q is not None:\n 
           query = self.norm_q(query)\n        if self.norm_k is not None:\n            key = self.norm_k(key)\n\n        if freqs_cis is not None:\n            query = self.apply_rotary_emb(query, freqs_cis)\n            key = self.apply_rotary_emb(key, freqs_cis)\n\n        # Cast to correct dtype\n        dtype = query.dtype\n        query, key = query.to(dtype), key.to(dtype)\n\n        # Compute joint attention\n        hidden_states = attention_forward(\n            query,\n            key,\n            value,\n            q_pattern=\"b s n d\", k_pattern=\"b s n d\", v_pattern=\"b s n d\", out_pattern=\"b s n d\",\n            attn_mask=attention_mask,\n        )\n\n        # Reshape back\n        hidden_states = hidden_states.flatten(2, 3)\n        hidden_states = hidden_states.to(dtype)\n\n        output = self.to_out[0](hidden_states)\n        if len(self.to_out) > 1:  # dropout\n            output = self.to_out[1](output)\n\n        return output\n\n\ndef select_per_token(\n    value_noisy: torch.Tensor,\n    value_clean: torch.Tensor,\n    noise_mask: torch.Tensor,\n    seq_len: int,\n) -> torch.Tensor:\n    noise_mask_expanded = noise_mask.unsqueeze(-1)  # (batch, seq_len, 1)\n    return torch.where(\n        noise_mask_expanded == 1,\n        value_noisy.unsqueeze(1).expand(-1, seq_len, -1),\n        value_clean.unsqueeze(1).expand(-1, seq_len, -1),\n    )\n\n\nclass ZImageTransformerBlock(nn.Module):\n    def __init__(\n        self,\n        layer_id: int,\n        dim: int,\n        n_heads: int,\n        n_kv_heads: int,\n        norm_eps: float,\n        qk_norm: bool,\n        modulation=True,\n    ):\n        super().__init__()\n        self.dim = dim\n        self.head_dim = dim // n_heads\n\n        # Refactored to use diffusers Attention with custom processor\n        # Original Z-Image params: dim, n_heads, n_kv_heads, qk_norm\n        self.attention = Attention(\n            q_dim=dim,\n            num_heads=n_heads,\n            
head_dim=dim // n_heads,\n        )\n\n        self.feed_forward = FeedForward(dim=dim, hidden_dim=int(dim / 3 * 8))\n        self.layer_id = layer_id\n\n        self.attention_norm1 = RMSNorm(dim, eps=norm_eps)\n        self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)\n\n        self.attention_norm2 = RMSNorm(dim, eps=norm_eps)\n        self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)\n\n        self.modulation = modulation\n        if modulation:\n            self.adaLN_modulation = nn.Sequential(\n                nn.Linear(min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True),\n            )\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        attn_mask: torch.Tensor,\n        freqs_cis: torch.Tensor,\n        adaln_input: Optional[torch.Tensor] = None,\n        noise_mask: Optional[torch.Tensor] = None,\n        adaln_noisy: Optional[torch.Tensor] = None,\n        adaln_clean: Optional[torch.Tensor] = None,\n    ):\n        if self.modulation:\n            seq_len = x.shape[1]\n\n            if noise_mask is not None:\n                # Per-token modulation: different modulation for noisy/clean tokens\n                mod_noisy = self.adaLN_modulation(adaln_noisy)\n                mod_clean = self.adaLN_modulation(adaln_clean)\n\n                scale_msa_noisy, gate_msa_noisy, scale_mlp_noisy, gate_mlp_noisy = mod_noisy.chunk(4, dim=1)\n                scale_msa_clean, gate_msa_clean, scale_mlp_clean, gate_mlp_clean = mod_clean.chunk(4, dim=1)\n\n                gate_msa_noisy, gate_mlp_noisy = gate_msa_noisy.tanh(), gate_mlp_noisy.tanh()\n                gate_msa_clean, gate_mlp_clean = gate_msa_clean.tanh(), gate_mlp_clean.tanh()\n\n                scale_msa_noisy, scale_mlp_noisy = 1.0 + scale_msa_noisy, 1.0 + scale_mlp_noisy\n                scale_msa_clean, scale_mlp_clean = 1.0 + scale_msa_clean, 1.0 + scale_mlp_clean\n\n                scale_msa = select_per_token(scale_msa_noisy, scale_msa_clean, noise_mask, seq_len)\n                scale_mlp = 
select_per_token(scale_mlp_noisy, scale_mlp_clean, noise_mask, seq_len)\n                gate_msa = select_per_token(gate_msa_noisy, gate_msa_clean, noise_mask, seq_len)\n                gate_mlp = select_per_token(gate_mlp_noisy, gate_mlp_clean, noise_mask, seq_len)\n            else:\n                # Global modulation: same modulation for all tokens (avoid double select)\n                mod = self.adaLN_modulation(adaln_input)\n                scale_msa, gate_msa, scale_mlp, gate_mlp = mod.unsqueeze(1).chunk(4, dim=2)\n                gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()\n                scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp\n\n            # Attention block\n            attn_out = self.attention(\n                self.attention_norm1(x) * scale_msa, attention_mask=attn_mask, freqs_cis=freqs_cis\n            )\n            x = x + gate_msa * self.attention_norm2(attn_out)\n\n            # FFN block\n            x = x + gate_mlp * self.ffn_norm2(self.feed_forward(self.ffn_norm1(x) * scale_mlp))\n        else:\n            # Attention block\n            attn_out = self.attention(self.attention_norm1(x), attention_mask=attn_mask, freqs_cis=freqs_cis)\n            x = x + self.attention_norm2(attn_out)\n\n            # FFN block\n            x = x + self.ffn_norm2(self.feed_forward(self.ffn_norm1(x)))\n\n        return x\n\n\nclass FinalLayer(nn.Module):\n    def __init__(self, hidden_size, out_channels):\n        super().__init__()\n        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)\n        self.linear = nn.Linear(hidden_size, out_channels, bias=True)\n\n        self.adaLN_modulation = nn.Sequential(\n            nn.SiLU(),\n            nn.Linear(min(hidden_size, ADALN_EMBED_DIM), hidden_size, bias=True),\n        )\n\n    def forward(self, x, c=None, noise_mask=None, c_noisy=None, c_clean=None):\n        seq_len = x.shape[1]\n\n        if noise_mask is not None:\n            # Per-token 
modulation\n            scale_noisy = 1.0 + self.adaLN_modulation(c_noisy)\n            scale_clean = 1.0 + self.adaLN_modulation(c_clean)\n            scale = select_per_token(scale_noisy, scale_clean, noise_mask, seq_len)\n        else:\n            # Original global modulation\n            assert c is not None, \"Either c or (c_noisy, c_clean) must be provided\"\n            scale = 1.0 + self.adaLN_modulation(c)\n            scale = scale.unsqueeze(1)\n\n        x = self.norm_final(x) * scale\n        x = self.linear(x)\n        return x\n\n\nclass RopeEmbedder:\n    def __init__(\n        self,\n        theta: float = 256.0,\n        axes_dims: List[int] = (16, 56, 56),\n        axes_lens: List[int] = (64, 128, 128),\n    ):\n        self.theta = theta\n        self.axes_dims = axes_dims\n        self.axes_lens = axes_lens\n        assert len(axes_dims) == len(axes_lens), \"axes_dims and axes_lens must have the same length\"\n        self.freqs_cis = None\n\n    @staticmethod\n    def precompute_freqs_cis(dim: List[int], end: List[int], theta: float = 256.0):\n        with torch.device(\"cpu\"):\n            freqs_cis = []\n            for i, (d, e) in enumerate(zip(dim, end)):\n                freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device=\"cpu\") / d))\n                timestep = torch.arange(e, device=freqs.device, dtype=torch.float64)\n                freqs = torch.outer(timestep, freqs).float()\n                freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64)  # complex64\n                freqs_cis.append(freqs_cis_i)\n\n            return freqs_cis\n\n    def __call__(self, ids: torch.Tensor):\n        assert ids.ndim == 2\n        assert ids.shape[-1] == len(self.axes_dims)\n        device = ids.device\n\n        if self.freqs_cis is None:\n            self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)\n            self.freqs_cis = [freqs_cis.to(device) 
for freqs_cis in self.freqs_cis]\n\n        result = []\n        for i in range(len(self.axes_dims)):\n            index = ids[:, i]\n            if IS_NPU_AVAILABLE:\n                result.append(torch.index_select(self.freqs_cis[i], 0, index))\n            else:\n                result.append(self.freqs_cis[i][index])\n        return torch.cat(result, dim=-1)\n\n\nclass ZImageDiT(nn.Module):\n    _supports_gradient_checkpointing = True\n    _no_split_modules = [\"ZImageTransformerBlock\"]\n\n    def __init__(\n        self,\n        all_patch_size=(2,),\n        all_f_patch_size=(1,),\n        in_channels=16,\n        dim=3840,\n        n_layers=30,\n        n_refiner_layers=2,\n        n_heads=30,\n        n_kv_heads=30,\n        norm_eps=1e-5,\n        qk_norm=True,\n        cap_feat_dim=2560,\n        rope_theta=256.0,\n        t_scale=1000.0,\n        axes_dims=[32, 48, 48],\n        axes_lens=[1024, 512, 512],\n        siglip_feat_dim=None,\n    ) -> None:\n        super().__init__()\n        self.in_channels = in_channels\n        self.out_channels = in_channels\n        self.all_patch_size = all_patch_size\n        self.all_f_patch_size = all_f_patch_size\n        self.dim = dim\n        self.n_heads = n_heads\n\n        self.rope_theta = rope_theta\n        self.t_scale = t_scale\n        self.gradient_checkpointing = False\n\n        assert len(all_patch_size) == len(all_f_patch_size)\n\n        all_x_embedder = {}\n        all_final_layer = {}\n        for patch_idx, (patch_size, f_patch_size) in enumerate(zip(all_patch_size, all_f_patch_size)):\n            x_embedder = nn.Linear(f_patch_size * patch_size * patch_size * in_channels, dim, bias=True)\n            all_x_embedder[f\"{patch_size}-{f_patch_size}\"] = x_embedder\n\n            final_layer = FinalLayer(dim, patch_size * patch_size * f_patch_size * self.out_channels)\n            all_final_layer[f\"{patch_size}-{f_patch_size}\"] = final_layer\n\n        self.all_x_embedder = 
nn.ModuleDict(all_x_embedder)\n        self.all_final_layer = nn.ModuleDict(all_final_layer)\n        self.noise_refiner = nn.ModuleList(\n            [\n                ZImageTransformerBlock(\n                    1000 + layer_id,\n                    dim,\n                    n_heads,\n                    n_kv_heads,\n                    norm_eps,\n                    qk_norm,\n                    modulation=True,\n                )\n                for layer_id in range(n_refiner_layers)\n            ]\n        )\n        self.context_refiner = nn.ModuleList(\n            [\n                ZImageTransformerBlock(\n                    layer_id,\n                    dim,\n                    n_heads,\n                    n_kv_heads,\n                    norm_eps,\n                    qk_norm,\n                    modulation=False,\n                )\n                for layer_id in range(n_refiner_layers)\n            ]\n        )\n        self.t_embedder = TimestepEmbedder(min(dim, ADALN_EMBED_DIM), mid_size=1024)\n        self.cap_embedder = nn.Sequential(\n            RMSNorm(cap_feat_dim, eps=norm_eps),\n            nn.Linear(cap_feat_dim, dim, bias=True),\n        )\n\n        # Optional SigLIP components (for Omni variant)\n        self.siglip_feat_dim = siglip_feat_dim\n        if siglip_feat_dim is not None:\n            self.siglip_embedder = nn.Sequential(\n                RMSNorm(siglip_feat_dim, eps=norm_eps), nn.Linear(siglip_feat_dim, dim, bias=True)\n            )\n            self.siglip_refiner = nn.ModuleList(\n                [\n                    ZImageTransformerBlock(\n                        2000 + layer_id,\n                        dim,\n                        n_heads,\n                        n_kv_heads,\n                        norm_eps,\n                        qk_norm,\n                        modulation=False,\n                    )\n                    for layer_id in range(n_refiner_layers)\n                ]\n            )\n      
      self.siglip_pad_token = nn.Parameter(torch.empty((1, dim)))\n        else:\n            self.siglip_embedder = None\n            self.siglip_refiner = None\n            self.siglip_pad_token = None\n\n        self.x_pad_token = nn.Parameter(torch.empty((1, dim)))\n        self.cap_pad_token = nn.Parameter(torch.empty((1, dim)))\n\n        self.layers = nn.ModuleList(\n            [\n                ZImageTransformerBlock(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm)\n                for layer_id in range(n_layers)\n            ]\n        )\n        head_dim = dim // n_heads\n        assert head_dim == sum(axes_dims)\n        self.axes_dims = axes_dims\n        self.axes_lens = axes_lens\n\n        self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens)\n\n    def unpatchify(\n        self,\n        x: List[torch.Tensor],\n        size: List[Tuple],\n        patch_size = 2,\n        f_patch_size = 1,\n        x_pos_offsets: Optional[List[Tuple[int, int]]] = None,\n    ) -> List[torch.Tensor]:\n        pH = pW = patch_size\n        pF = f_patch_size\n        bsz = len(x)\n        assert len(size) == bsz\n\n        if x_pos_offsets is not None:\n            # Omni: extract target image from unified sequence (cond_images + target)\n            result = []\n            for i in range(bsz):\n                unified_x = x[i][x_pos_offsets[i][0] : x_pos_offsets[i][1]]\n                cu_len = 0\n                x_item = None\n                for j in range(len(size[i])):\n                    if size[i][j] is None:\n                        ori_len = 0\n                        pad_len = SEQ_MULTI_OF\n                        cu_len += pad_len + ori_len\n                    else:\n                        F, H, W = size[i][j]\n                        ori_len = (F // pF) * (H // pH) * (W // pW)\n                        pad_len = (-ori_len) % SEQ_MULTI_OF\n                        x_item = (\n                            
unified_x[cu_len : cu_len + ori_len]\n                            .view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels)\n                            .permute(6, 0, 3, 1, 4, 2, 5)\n                            .reshape(self.out_channels, F, H, W)\n                        )\n                        cu_len += ori_len + pad_len\n                result.append(x_item)  # Return only the last (target) image\n            return result\n        else:\n            # Original mode: simple unpatchify\n            for i in range(bsz):\n                F, H, W = size[i]\n                ori_len = (F // pF) * (H // pH) * (W // pW)\n                # \"f h w pf ph pw c -> c (f pf) (h ph) (w pw)\"\n                x[i] = (\n                    x[i][:ori_len]\n                    .view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels)\n                    .permute(6, 0, 3, 1, 4, 2, 5)\n                    .reshape(self.out_channels, F, H, W)\n                )\n            return x\n\n    @staticmethod\n    def create_coordinate_grid(size, start=None, device=None):\n        if start is None:\n            start = (0 for _ in size)\n\n        axes = [torch.arange(x0, x0 + span, dtype=torch.int32, device=device) for x0, span in zip(start, size)]\n        grids = torch.meshgrid(axes, indexing=\"ij\")\n        return torch.stack(grids, dim=-1)\n\n    def patchify_and_embed(\n        self,\n        all_image: List[torch.Tensor],\n        all_cap_feats: List[torch.Tensor],\n        patch_size: int = 2,\n        f_patch_size: int = 1,\n    ):\n        pH = pW = patch_size\n        pF = f_patch_size\n        device = all_image[0].device\n\n        all_image_out = []\n        all_image_size = []\n        all_image_pos_ids = []\n        all_image_pad_mask = []\n        all_cap_pos_ids = []\n        all_cap_pad_mask = []\n        all_cap_feats_out = []\n\n        for i, (image, cap_feat) in enumerate(zip(all_image, all_cap_feats)):\n            ### Process Caption\n            
cap_ori_len = len(cap_feat)\n            cap_padding_len = (-cap_ori_len) % SEQ_MULTI_OF\n            # padded position ids\n            cap_padded_pos_ids = self.create_coordinate_grid(\n                size=(cap_ori_len + cap_padding_len, 1, 1),\n                start=(1, 0, 0),\n                device=device,\n            ).flatten(0, 2)\n            all_cap_pos_ids.append(cap_padded_pos_ids)\n            # pad mask\n            all_cap_pad_mask.append(\n                torch.cat(\n                    [\n                        torch.zeros((cap_ori_len,), dtype=torch.bool, device=device),\n                        torch.ones((cap_padding_len,), dtype=torch.bool, device=device),\n                    ],\n                    dim=0,\n                )\n            )\n            # padded feature\n            cap_padded_feat = torch.cat(\n                [cap_feat, cap_feat[-1:].repeat(cap_padding_len, 1)],\n                dim=0,\n            )\n            all_cap_feats_out.append(cap_padded_feat)\n\n            ### Process Image\n            C, F, H, W = image.size()\n            all_image_size.append((F, H, W))\n            F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW\n\n            image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)\n            # \"c f pf h ph w pw -> (f h w) (pf ph pw c)\"\n            image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C)\n\n            image_ori_len = len(image)\n            image_padding_len = (-image_ori_len) % SEQ_MULTI_OF\n\n            image_ori_pos_ids = self.create_coordinate_grid(\n                size=(F_tokens, H_tokens, W_tokens),\n                start=(cap_ori_len + cap_padding_len + 1, 0, 0),\n                device=device,\n            ).flatten(0, 2)\n            image_padding_pos_ids = (\n                self.create_coordinate_grid(\n                    size=(1, 1, 1),\n                    start=(0, 0, 0),\n                    
device=device,\n                )\n                .flatten(0, 2)\n                .repeat(image_padding_len, 1)\n            )\n            image_padded_pos_ids = torch.cat([image_ori_pos_ids, image_padding_pos_ids], dim=0)\n            all_image_pos_ids.append(image_padded_pos_ids)\n            # pad mask\n            all_image_pad_mask.append(\n                torch.cat(\n                    [\n                        torch.zeros((image_ori_len,), dtype=torch.bool, device=device),\n                        torch.ones((image_padding_len,), dtype=torch.bool, device=device),\n                    ],\n                    dim=0,\n                )\n            )\n            # padded feature\n            image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0)\n            all_image_out.append(image_padded_feat)\n\n        return all_image_out, all_cap_feats_out, {\n            \"x_size\": all_image_size,\n            \"x_pos_ids\": all_image_pos_ids,\n            \"cap_pos_ids\": all_cap_pos_ids,\n            \"x_pad_mask\": all_image_pad_mask,\n            \"cap_pad_mask\": all_cap_pad_mask\n        }\n    # (\n    #         all_img_out,\n    #         all_cap_out,\n    #         all_img_size,\n    #         all_img_pos_ids,\n    #         all_cap_pos_ids,\n    #         all_img_pad_mask,\n    #         all_cap_pad_mask,\n    #     )\n\n    def patchify_controlnet(\n        self,\n        all_image: List[torch.Tensor],\n        patch_size: int = 2,\n        f_patch_size: int = 1,\n        cap_padding_len: int = None,\n    ):\n        pH = pW = patch_size\n        pF = f_patch_size\n        device = all_image[0].device\n\n        all_image_out = []\n        all_image_size = []\n        all_image_pos_ids = []\n        all_image_pad_mask = []\n\n        for i, image in enumerate(all_image):\n            ### Process Image\n            C, F, H, W = image.size()\n            all_image_size.append((F, H, W))\n            F_tokens, H_tokens, 
W_tokens = F // pF, H // pH, W // pW\n\n            image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)\n            # \"c f pf h ph w pw -> (f h w) (pf ph pw c)\"\n            image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C)\n\n            image_ori_len = len(image)\n            image_padding_len = (-image_ori_len) % SEQ_MULTI_OF\n\n            image_ori_pos_ids = self.create_coordinate_grid(\n                size=(F_tokens, H_tokens, W_tokens),\n                start=(cap_padding_len + 1, 0, 0),\n                device=device,\n            ).flatten(0, 2)\n            image_padding_pos_ids = (\n                self.create_coordinate_grid(\n                    size=(1, 1, 1),\n                    start=(0, 0, 0),\n                    device=device,\n                )\n                .flatten(0, 2)\n                .repeat(image_padding_len, 1)\n            )\n            image_padded_pos_ids = torch.cat([image_ori_pos_ids, image_padding_pos_ids], dim=0)\n            all_image_pos_ids.append(image_padded_pos_ids)\n            # pad mask\n            all_image_pad_mask.append(\n                torch.cat(\n                    [\n                        torch.zeros((image_ori_len,), dtype=torch.bool, device=device),\n                        torch.ones((image_padding_len,), dtype=torch.bool, device=device),\n                    ],\n                    dim=0,\n                )\n            )\n            # padded feature\n            image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0)\n            all_image_out.append(image_padded_feat)\n\n        return (\n            all_image_out,\n            all_image_size,\n            all_image_pos_ids,\n            all_image_pad_mask,\n        )\n    \n    def _prepare_sequence(\n        self,\n        feats: List[torch.Tensor],\n        pos_ids: List[torch.Tensor],\n        inner_pad_mask: List[torch.Tensor],\n        pad_token: 
torch.nn.Parameter,\n        noise_mask: Optional[List[List[int]]] = None,\n        device: torch.device = None,\n    ):\n        \"\"\"Prepare sequence: apply pad token, RoPE embed, pad to batch, create attention mask.\"\"\"\n        item_seqlens = [len(f) for f in feats]\n        max_seqlen = max(item_seqlens)\n        bsz = len(feats)\n\n        # Pad token\n        feats_cat = torch.cat(feats, dim=0)\n        feats_cat[torch.cat(inner_pad_mask)] = pad_token.to(dtype=feats_cat.dtype, device=feats_cat.device)\n        feats = list(feats_cat.split(item_seqlens, dim=0))\n\n        # RoPE\n        freqs_cis = list(self.rope_embedder(torch.cat(pos_ids, dim=0)).split([len(p) for p in pos_ids], dim=0))\n\n        # Pad to batch\n        feats = pad_sequence(feats, batch_first=True, padding_value=0.0)\n        freqs_cis = pad_sequence(freqs_cis, batch_first=True, padding_value=0.0)[:, : feats.shape[1]]\n\n        # Attention mask\n        attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device)\n        for i, seq_len in enumerate(item_seqlens):\n            attn_mask[i, :seq_len] = 1\n\n        # Noise mask\n        noise_mask_tensor = None\n        if noise_mask is not None:\n            noise_mask_tensor = pad_sequence(\n                [torch.tensor(m, dtype=torch.long, device=device) for m in noise_mask],\n                batch_first=True,\n                padding_value=0,\n            )[:, : feats.shape[1]]\n\n        return feats, freqs_cis, attn_mask, item_seqlens, noise_mask_tensor\n    \n    def _build_unified_sequence(\n        self,\n        x: torch.Tensor,\n        x_freqs: torch.Tensor,\n        x_seqlens: List[int],\n        x_noise_mask: Optional[List[List[int]]],\n        cap: torch.Tensor,\n        cap_freqs: torch.Tensor,\n        cap_seqlens: List[int],\n        cap_noise_mask: Optional[List[List[int]]],\n        siglip: Optional[torch.Tensor],\n        siglip_freqs: Optional[torch.Tensor],\n        siglip_seqlens: 
Optional[List[int]],\n        siglip_noise_mask: Optional[List[List[int]]],\n        omni_mode: bool,\n        device: torch.device,\n    ):\n        \"\"\"Build unified sequence: x, cap, and optionally siglip.\n        Basic mode order: [x, cap]; Omni mode order: [cap, x, siglip]\n        \"\"\"\n        bsz = len(x_seqlens)\n        unified = []\n        unified_freqs = []\n        unified_noise_mask = []\n\n        for i in range(bsz):\n            x_len, cap_len = x_seqlens[i], cap_seqlens[i]\n\n            if omni_mode:\n                # Omni: [cap, x, siglip]\n                if siglip is not None and siglip_seqlens is not None:\n                    sig_len = siglip_seqlens[i]\n                    unified.append(torch.cat([cap[i][:cap_len], x[i][:x_len], siglip[i][:sig_len]]))\n                    unified_freqs.append(\n                        torch.cat([cap_freqs[i][:cap_len], x_freqs[i][:x_len], siglip_freqs[i][:sig_len]])\n                    )\n                    unified_noise_mask.append(\n                        torch.tensor(\n                            cap_noise_mask[i] + x_noise_mask[i] + siglip_noise_mask[i], dtype=torch.long, device=device\n                        )\n                    )\n                else:\n                    unified.append(torch.cat([cap[i][:cap_len], x[i][:x_len]]))\n                    unified_freqs.append(torch.cat([cap_freqs[i][:cap_len], x_freqs[i][:x_len]]))\n                    unified_noise_mask.append(\n                        torch.tensor(cap_noise_mask[i] + x_noise_mask[i], dtype=torch.long, device=device)\n                    )\n            else:\n                # Basic: [x, cap]\n                unified.append(torch.cat([x[i][:x_len], cap[i][:cap_len]]))\n                unified_freqs.append(torch.cat([x_freqs[i][:x_len], cap_freqs[i][:cap_len]]))\n\n        # Compute unified seqlens\n        if omni_mode:\n            if siglip is not None and siglip_seqlens is not None:\n                unified_seqlens = [a 
+ b + c for a, b, c in zip(cap_seqlens, x_seqlens, siglip_seqlens)]\n            else:\n                unified_seqlens = [a + b for a, b in zip(cap_seqlens, x_seqlens)]\n        else:\n            unified_seqlens = [a + b for a, b in zip(x_seqlens, cap_seqlens)]\n\n        max_seqlen = max(unified_seqlens)\n\n        # Pad to batch\n        unified = pad_sequence(unified, batch_first=True, padding_value=0.0)\n        unified_freqs = pad_sequence(unified_freqs, batch_first=True, padding_value=0.0)\n\n        # Attention mask\n        attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device)\n        for i, seq_len in enumerate(unified_seqlens):\n            attn_mask[i, :seq_len] = 1\n\n        # Noise mask\n        noise_mask_tensor = None\n        if omni_mode:\n            noise_mask_tensor = pad_sequence(unified_noise_mask, batch_first=True, padding_value=0)[\n                :, : unified.shape[1]\n            ]\n\n        return unified, unified_freqs, attn_mask, noise_mask_tensor\n    \n    def _pad_with_ids(\n        self,\n        feat: torch.Tensor,\n        pos_grid_size: Tuple,\n        pos_start: Tuple,\n        device: torch.device,\n        noise_mask_val: Optional[int] = None,\n    ):\n        \"\"\"Pad feature to SEQ_MULTI_OF, create position IDs and pad mask.\"\"\"\n        ori_len = len(feat)\n        pad_len = (-ori_len) % SEQ_MULTI_OF\n        total_len = ori_len + pad_len\n\n        # Pos IDs\n        ori_pos_ids = self.create_coordinate_grid(size=pos_grid_size, start=pos_start, device=device).flatten(0, 2)\n        if pad_len > 0:\n            pad_pos_ids = (\n                self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device)\n                .flatten(0, 2)\n                .repeat(pad_len, 1)\n            )\n            pos_ids = torch.cat([ori_pos_ids, pad_pos_ids], dim=0)\n            padded_feat = torch.cat([feat, feat[-1:].repeat(pad_len, 1)], dim=0)\n            pad_mask = torch.cat(\n           
     [\n                    torch.zeros(ori_len, dtype=torch.bool, device=device),\n                    torch.ones(pad_len, dtype=torch.bool, device=device),\n                ]\n            )\n        else:\n            pos_ids = ori_pos_ids\n            padded_feat = feat\n            pad_mask = torch.zeros(ori_len, dtype=torch.bool, device=device)\n\n        noise_mask = [noise_mask_val] * total_len if noise_mask_val is not None else None  # token level\n        return padded_feat, pos_ids, pad_mask, total_len, noise_mask\n    \n    def _patchify_image(self, image: torch.Tensor, patch_size: int, f_patch_size: int):\n        \"\"\"Patchify a single image tensor: (C, F, H, W) -> (num_patches, patch_dim).\"\"\"\n        pH, pW, pF = patch_size, patch_size, f_patch_size\n        C, F, H, W = image.size()\n        F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW\n        image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)\n        image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C)\n        return image, (F, H, W), (F_tokens, H_tokens, W_tokens)\n    \n    def patchify_and_embed_omni(\n        self,\n        all_x: List[List[torch.Tensor]],\n        all_cap_feats: List[List[torch.Tensor]],\n        all_siglip_feats: List[List[torch.Tensor]],\n        patch_size: int = 2,\n        f_patch_size: int = 1,\n        images_noise_mask: List[List[int]] = None,\n    ):\n        \"\"\"Patchify for omni mode: multiple images per batch item with noise masks.\"\"\"\n        bsz = len(all_x)\n        device = all_x[0][-1].device\n        dtype = all_x[0][-1].dtype\n\n        all_x_out, all_x_size, all_x_pos_ids, all_x_pad_mask, all_x_len, all_x_noise_mask = [], [], [], [], [], []\n        all_cap_out, all_cap_pos_ids, all_cap_pad_mask, all_cap_len, all_cap_noise_mask = [], [], [], [], []\n        all_sig_out, all_sig_pos_ids, all_sig_pad_mask, all_sig_len, all_sig_noise_mask = [], [], [], [], []\n\n        for i 
in range(bsz):\n            num_images = len(all_x[i])\n            cap_feats_list, cap_pos_list, cap_mask_list, cap_lens, cap_noise = [], [], [], [], []\n            cap_end_pos = []\n            cap_cu_len = 1\n\n            # Process captions\n            for j, cap_item in enumerate(all_cap_feats[i]):\n                noise_val = images_noise_mask[i][j] if j < len(images_noise_mask[i]) else 1\n                cap_out, cap_pos, cap_mask, cap_len, cap_nm = self._pad_with_ids(\n                    cap_item,\n                    (len(cap_item) + (-len(cap_item)) % SEQ_MULTI_OF, 1, 1),\n                    (cap_cu_len, 0, 0),\n                    device,\n                    noise_val,\n                )\n                cap_feats_list.append(cap_out)\n                cap_pos_list.append(cap_pos)\n                cap_mask_list.append(cap_mask)\n                cap_lens.append(cap_len)\n                cap_noise.extend(cap_nm)\n                cap_cu_len += len(cap_item)\n                cap_end_pos.append(cap_cu_len)\n                cap_cu_len += 2  # for image vae and siglip tokens\n\n            all_cap_out.append(torch.cat(cap_feats_list, dim=0))\n            all_cap_pos_ids.append(torch.cat(cap_pos_list, dim=0))\n            all_cap_pad_mask.append(torch.cat(cap_mask_list, dim=0))\n            all_cap_len.append(cap_lens)\n            all_cap_noise_mask.append(cap_noise)\n\n            # Process images\n            x_feats_list, x_pos_list, x_mask_list, x_lens, x_size, x_noise = [], [], [], [], [], []\n            for j, x_item in enumerate(all_x[i]):\n                noise_val = images_noise_mask[i][j]\n                if x_item is not None:\n                    x_patches, size, (F_t, H_t, W_t) = self._patchify_image(x_item, patch_size, f_patch_size)\n                    x_out, x_pos, x_mask, x_len, x_nm = self._pad_with_ids(\n                        x_patches, (F_t, H_t, W_t), (cap_end_pos[j], 0, 0), device, noise_val\n                    )\n                  
  x_size.append(size)\n                else:\n                    x_len = SEQ_MULTI_OF\n                    x_out = torch.zeros((x_len, X_PAD_DIM), dtype=dtype, device=device)\n                    x_pos = self.create_coordinate_grid((1, 1, 1), (0, 0, 0), device).flatten(0, 2).repeat(x_len, 1)\n                    x_mask = torch.ones(x_len, dtype=torch.bool, device=device)\n                    x_nm = [noise_val] * x_len\n                    x_size.append(None)\n                x_feats_list.append(x_out)\n                x_pos_list.append(x_pos)\n                x_mask_list.append(x_mask)\n                x_lens.append(x_len)\n                x_noise.extend(x_nm)\n\n            all_x_out.append(torch.cat(x_feats_list, dim=0))\n            all_x_pos_ids.append(torch.cat(x_pos_list, dim=0))\n            all_x_pad_mask.append(torch.cat(x_mask_list, dim=0))\n            all_x_size.append(x_size)\n            all_x_len.append(x_lens)\n            all_x_noise_mask.append(x_noise)\n\n            # Process siglip\n            if all_siglip_feats[i] is None:\n                all_sig_len.append([0] * num_images)\n                all_sig_out.append(None)\n            else:\n                sig_feats_list, sig_pos_list, sig_mask_list, sig_lens, sig_noise = [], [], [], [], []\n                for j, sig_item in enumerate(all_siglip_feats[i]):\n                    noise_val = images_noise_mask[i][j]\n                    if sig_item is not None:\n                        sig_H, sig_W, sig_C = sig_item.size()\n                        sig_flat = sig_item.permute(2, 0, 1).reshape(sig_H * sig_W, sig_C)\n                        sig_out, sig_pos, sig_mask, sig_len, sig_nm = self._pad_with_ids(\n                            sig_flat, (1, sig_H, sig_W), (cap_end_pos[j] + 1, 0, 0), device, noise_val\n                        )\n                        # Scale position IDs to match x resolution\n                        if x_size[j] is not None:\n                            sig_pos = 
sig_pos.float()\n                            sig_pos[..., 1] = sig_pos[..., 1] / max(sig_H - 1, 1) * (x_size[j][1] - 1)\n                            sig_pos[..., 2] = sig_pos[..., 2] / max(sig_W - 1, 1) * (x_size[j][2] - 1)\n                            sig_pos = sig_pos.to(torch.int32)\n                    else:\n                        sig_len = SEQ_MULTI_OF\n                        sig_out = torch.zeros((sig_len, self.siglip_feat_dim), dtype=dtype, device=device)\n                        sig_pos = (\n                            self.create_coordinate_grid((1, 1, 1), (0, 0, 0), device).flatten(0, 2).repeat(sig_len, 1)\n                        )\n                        sig_mask = torch.ones(sig_len, dtype=torch.bool, device=device)\n                        sig_nm = [noise_val] * sig_len\n                    sig_feats_list.append(sig_out)\n                    sig_pos_list.append(sig_pos)\n                    sig_mask_list.append(sig_mask)\n                    sig_lens.append(sig_len)\n                    sig_noise.extend(sig_nm)\n\n                all_sig_out.append(torch.cat(sig_feats_list, dim=0))\n                all_sig_pos_ids.append(torch.cat(sig_pos_list, dim=0))\n                all_sig_pad_mask.append(torch.cat(sig_mask_list, dim=0))\n                all_sig_len.append(sig_lens)\n                all_sig_noise_mask.append(sig_noise)\n\n        # Compute x position offsets\n        all_x_pos_offsets = [(sum(all_cap_len[i]), sum(all_cap_len[i]) + sum(all_x_len[i])) for i in range(bsz)]\n\n        return (\n            all_x_out,\n            all_cap_out,\n            all_sig_out,\n            all_x_size,\n            all_x_pos_ids,\n            all_cap_pos_ids,\n            all_sig_pos_ids,\n            all_x_pad_mask,\n            all_cap_pad_mask,\n            all_sig_pad_mask,\n            all_x_pos_offsets,\n            all_x_noise_mask,\n            all_cap_noise_mask,\n            all_sig_noise_mask,\n        )\n        return all_x_out, all_cap_out, 
all_sig_out, {\n            \"x_size\": x_size,\n            \"x_pos_ids\": all_x_pos_ids,\n            \"cap_pos_ids\": all_cap_pos_ids,\n            \"sig_pos_ids\": all_sig_pos_ids,\n            \"x_pad_mask\": all_x_pad_mask,\n            \"cap_pad_mask\": all_cap_pad_mask,\n            \"sig_pad_mask\": all_sig_pad_mask,\n            \"x_pos_offsets\": all_x_pos_offsets,\n            \"x_noise_mask\": all_x_noise_mask,\n            \"cap_noise_mask\": all_cap_noise_mask,\n            \"sig_noise_mask\": all_sig_noise_mask,\n        }\n\n    def forward(\n        self,\n        x: List[torch.Tensor],\n        t,\n        cap_feats: List[torch.Tensor],\n        siglip_feats = None,\n        image_noise_mask = None,\n        patch_size=2,\n        f_patch_size=1,\n        use_gradient_checkpointing=False,\n        use_gradient_checkpointing_offload=False,\n    ):\n        assert patch_size in self.all_patch_size and f_patch_size in self.all_f_patch_size\n        omni_mode = isinstance(x[0], list)\n        device = x[0][-1].device if omni_mode else x[0].device\n\n        if omni_mode:\n            # Dual embeddings: noisy (t) and clean (t=1)\n            t_noisy = self.t_embedder(t * self.t_scale).type_as(x[0][-1])\n            t_clean = self.t_embedder(torch.ones_like(t) * self.t_scale).type_as(x[0][-1])\n            adaln_input = None\n        else:\n            # Single embedding for all tokens\n            adaln_input = self.t_embedder(t * self.t_scale).type_as(x[0])\n            t_noisy = t_clean = None\n\n        # Patchify\n        if omni_mode:\n            (\n                x,\n                cap_feats,\n                siglip_feats,\n                x_size,\n                x_pos_ids,\n                cap_pos_ids,\n                siglip_pos_ids,\n                x_pad_mask,\n                cap_pad_mask,\n                siglip_pad_mask,\n                x_pos_offsets,\n                x_noise_mask,\n                cap_noise_mask,\n                
siglip_noise_mask,\n            ) = self.patchify_and_embed_omni(x, cap_feats, siglip_feats, patch_size, f_patch_size, image_noise_mask)\n        else:\n            (\n                x,\n                cap_feats,\n                x_size,\n                x_pos_ids,\n                cap_pos_ids,\n                x_pad_mask,\n                cap_pad_mask,\n            ) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size)\n            x_pos_offsets = x_noise_mask = cap_noise_mask = siglip_noise_mask = None\n\n        # x embed & refine\n        x_seqlens = [len(xi) for xi in x]\n        x = self.all_x_embedder[f\"{patch_size}-{f_patch_size}\"](torch.cat(x, dim=0))  # embed\n        x, x_freqs, x_mask, _, x_noise_tensor = self._prepare_sequence(\n            list(x.split(x_seqlens, dim=0)), x_pos_ids, x_pad_mask, self.x_pad_token, x_noise_mask, device\n        )\n\n        for layer in self.noise_refiner:\n            x = gradient_checkpoint_forward(\n                layer,\n                use_gradient_checkpointing=use_gradient_checkpointing,\n                use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,\n                x=x, attn_mask=x_mask, freqs_cis=x_freqs, adaln_input=adaln_input, noise_mask=x_noise_tensor, adaln_noisy=t_noisy, adaln_clean=t_clean,\n            )\n\n        # Cap embed & refine\n        cap_seqlens = [len(ci) for ci in cap_feats]\n        cap_feats = self.cap_embedder(torch.cat(cap_feats, dim=0))  # embed\n        cap_feats, cap_freqs, cap_mask, _, _ = self._prepare_sequence(\n            list(cap_feats.split(cap_seqlens, dim=0)), cap_pos_ids, cap_pad_mask, self.cap_pad_token, None, device\n        )\n\n        for layer in self.context_refiner:\n            cap_feats = gradient_checkpoint_forward(\n                layer,\n                use_gradient_checkpointing=use_gradient_checkpointing,\n                use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,\n                
x=cap_feats,\n                attn_mask=cap_mask,\n                freqs_cis=cap_freqs,\n            )\n\n        # Siglip embed & refine\n        siglip_seqlens = siglip_freqs = None\n        if omni_mode and siglip_feats[0] is not None and self.siglip_embedder is not None:\n            siglip_seqlens = [len(si) for si in siglip_feats]\n            siglip_feats = self.siglip_embedder(torch.cat(siglip_feats, dim=0))  # embed\n            siglip_feats, siglip_freqs, siglip_mask, _, _ = self._prepare_sequence(\n                list(siglip_feats.split(siglip_seqlens, dim=0)),\n                siglip_pos_ids,\n                siglip_pad_mask,\n                self.siglip_pad_token,\n                None,\n                device,\n            )\n\n            for layer in self.siglip_refiner:\n                siglip_feats = gradient_checkpoint_forward(\n                    layer,\n                    use_gradient_checkpointing=use_gradient_checkpointing,\n                    use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,\n                    x=siglip_feats, attn_mask=siglip_mask, freqs_cis=siglip_freqs,\n                )\n\n        # Unified sequence\n        unified, unified_freqs, unified_mask, unified_noise_tensor = self._build_unified_sequence(\n            x,\n            x_freqs,\n            x_seqlens,\n            x_noise_mask,\n            cap_feats,\n            cap_freqs,\n            cap_seqlens,\n            cap_noise_mask,\n            siglip_feats,\n            siglip_freqs,\n            siglip_seqlens,\n            siglip_noise_mask,\n            omni_mode,\n            device,\n        )\n\n        # Main transformer layers\n        for layer_idx, layer in enumerate(self.layers):\n            unified = gradient_checkpoint_forward(\n                layer,\n                use_gradient_checkpointing=use_gradient_checkpointing,\n                use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,\n                
x=unified, attn_mask=unified_mask, freqs_cis=unified_freqs, adaln_input=adaln_input, noise_mask=unified_noise_tensor, adaln_noisy=t_noisy, adaln_clean=t_clean\n            )\n\n        unified = (\n            self.all_final_layer[f\"{patch_size}-{f_patch_size}\"](\n                unified, noise_mask=unified_noise_tensor, c_noisy=t_noisy, c_clean=t_clean\n            )\n            if omni_mode\n            else self.all_final_layer[f\"{patch_size}-{f_patch_size}\"](unified, c=adaln_input)\n        )\n\n        # Unpatchify\n        x = self.unpatchify(list(unified.unbind(dim=0)), x_size, patch_size, f_patch_size, x_pos_offsets)\n\n        return x\n"
  },
  {
    "path": "diffsynth/models/z_image_image2lora.py",
    "content": "import torch\nfrom .qwen_image_image2lora import ImageEmbeddingToLoraMatrix, SequencialMLP\n\n\nclass LoRATrainerBlock(torch.nn.Module):\n    def __init__(self, lora_patterns, in_dim=1536+4096, compress_dim=128, rank=4, block_id=0, use_residual=True, residual_length=64+7, residual_dim=3584, residual_mid_dim=1024, prefix=\"transformer_blocks\"):\n        super().__init__()\n        self.prefix = prefix\n        self.lora_patterns = lora_patterns\n        self.block_id = block_id\n        self.layers = []\n        for name, lora_a_dim, lora_b_dim in self.lora_patterns:\n            self.layers.append(ImageEmbeddingToLoraMatrix(in_dim, compress_dim, lora_a_dim, lora_b_dim, rank))\n        self.layers = torch.nn.ModuleList(self.layers)\n        if use_residual:\n            self.proj_residual = SequencialMLP(residual_length, residual_dim, residual_mid_dim, compress_dim)\n        else:\n            self.proj_residual = None\n    \n    def forward(self, x, residual=None):\n        lora = {}\n        if self.proj_residual is not None: residual = self.proj_residual(residual)\n        for lora_pattern, layer in zip(self.lora_patterns, self.layers):\n            name = lora_pattern[0]\n            lora_a, lora_b = layer(x, residual=residual)\n            lora[f\"{self.prefix}.{self.block_id}.{name}.lora_A.default.weight\"] = lora_a\n            lora[f\"{self.prefix}.{self.block_id}.{name}.lora_B.default.weight\"] = lora_b\n        return lora\n\n\nclass ZImageImage2LoRAComponent(torch.nn.Module):\n    def __init__(self, lora_patterns, prefix, num_blocks=60, use_residual=True, compress_dim=128, rank=4, residual_length=64+7, residual_mid_dim=1024):\n        super().__init__()\n        self.lora_patterns = lora_patterns\n        self.num_blocks = num_blocks\n        self.blocks = []\n        for lora_patterns in self.lora_patterns:\n            for block_id in range(self.num_blocks):\n                self.blocks.append(LoRATrainerBlock(lora_patterns, 
block_id=block_id, use_residual=use_residual, compress_dim=compress_dim, rank=rank, residual_length=residual_length, residual_mid_dim=residual_mid_dim, prefix=prefix))\n        self.blocks = torch.nn.ModuleList(self.blocks)\n        self.residual_scale = 0.05\n        self.use_residual = use_residual\n        \n    def forward(self, x, residual=None):\n        if residual is not None:\n            if self.use_residual:\n                residual = residual * self.residual_scale\n            else:\n                residual = None\n        lora = {}\n        for block in self.blocks:\n            lora.update(block(x, residual))\n        return lora\n\n\nclass ZImageImage2LoRAModel(torch.nn.Module):\n    def __init__(self, use_residual=False, compress_dim=64, rank=4, residual_length=64+7, residual_mid_dim=1024):\n        super().__init__()\n        lora_patterns = [\n            [\n                (\"attention.to_q\", 3840, 3840),\n                (\"attention.to_k\", 3840, 3840),\n                (\"attention.to_v\", 3840, 3840),\n                (\"attention.to_out.0\", 3840, 3840),\n            ],\n            [\n                (\"feed_forward.w1\", 3840, 10240),\n                (\"feed_forward.w2\", 10240, 3840),\n                (\"feed_forward.w3\", 3840, 10240),\n            ],\n        ]\n        config = {\n            \"lora_patterns\": lora_patterns,\n            \"use_residual\": use_residual,\n            \"compress_dim\": compress_dim,\n            \"rank\": rank,\n            \"residual_length\": residual_length,\n            \"residual_mid_dim\": residual_mid_dim,\n        }\n        self.layers_lora = ZImageImage2LoRAComponent(\n            prefix=\"layers\",\n            num_blocks=30,\n            **config,\n        )\n        self.context_refiner_lora = ZImageImage2LoRAComponent(\n            prefix=\"context_refiner\",\n            num_blocks=2,\n            **config,\n        )\n        self.noise_refiner_lora = ZImageImage2LoRAComponent(\n      
      prefix=\"noise_refiner\",\n            num_blocks=2,\n            **config,\n        )\n        \n    def forward(self, x, residual=None):\n        lora = {}\n        lora.update(self.layers_lora(x, residual=residual))\n        lora.update(self.context_refiner_lora(x, residual=residual))\n        lora.update(self.noise_refiner_lora(x, residual=residual))\n        return lora\n\n    def initialize_weights(self):\n        state_dict = self.state_dict()\n        for name in state_dict:\n            if \".proj_a.\" in name:\n                state_dict[name] = state_dict[name] * 0.3\n            elif \".proj_b.proj_out.\" in name:\n                state_dict[name] = state_dict[name] * 0\n            elif \".proj_residual.proj_out.\" in name:\n                state_dict[name] = state_dict[name] * 0.3\n        self.load_state_dict(state_dict)\n\n\nclass ImageEmb2LoRAWeightCompressed(torch.nn.Module):\n    def __init__(self, in_dim, out_dim, emb_dim, rank):\n        super().__init__()\n        self.lora_a = torch.nn.Parameter(torch.randn((rank, in_dim)))\n        self.lora_b = torch.nn.Parameter(torch.randn((out_dim, rank)))\n        self.proj = torch.nn.Linear(emb_dim, rank * rank, bias=True)\n        self.rank = rank\n    \n    def forward(self, x):\n        x = self.proj(x).view(self.rank, self.rank)\n        lora_a = x @ self.lora_a\n        lora_b = self.lora_b\n        return lora_a, lora_b\n\n\nclass ZImageImage2LoRAModelCompressed(torch.nn.Module):\n    def __init__(self, emb_dim=1536+4096, rank=32):\n        super().__init__()\n        target_layers = [\n            (\"attention.to_q\", 3840, 3840),\n            (\"attention.to_k\", 3840, 3840),\n            (\"attention.to_v\", 3840, 3840),\n            (\"attention.to_out.0\", 3840, 3840),\n            (\"feed_forward.w1\", 3840, 10240),\n            (\"feed_forward.w2\", 10240, 3840),\n            (\"feed_forward.w3\", 3840, 10240),\n        ]\n        self.lora_patterns = [\n            {\n               
 \"prefix\": \"layers\",\n                \"num_layers\": 30,\n                \"target_layers\": target_layers,\n            },\n            {\n                \"prefix\": \"context_refiner\",\n                \"num_layers\": 2,\n                \"target_layers\": target_layers,\n            },\n            {\n                \"prefix\": \"noise_refiner\",\n                \"num_layers\": 2,\n                \"target_layers\": target_layers,\n            },\n        ]\n        module_dict = {}\n        for lora_pattern in self.lora_patterns:\n            prefix, num_layers, target_layers = lora_pattern[\"prefix\"], lora_pattern[\"num_layers\"], lora_pattern[\"target_layers\"]\n            for layer_id in range(num_layers):\n                for layer_name, in_dim, out_dim in target_layers:\n                    name = f\"{prefix}.{layer_id}.{layer_name}\".replace(\".\", \"___\")\n                    model = ImageEmb2LoRAWeightCompressed(in_dim, out_dim, emb_dim, rank)\n                    module_dict[name] = model\n        self.module_dict = torch.nn.ModuleDict(module_dict)\n\n    def forward(self, x, residual=None):\n        lora = {}\n        for name, module in self.module_dict.items():\n            name = name.replace(\"___\", \".\")\n            name_a, name_b = f\"{name}.lora_A.default.weight\", f\"{name}.lora_B.default.weight\"\n            lora_a, lora_b = module(x)\n            lora[name_a] = lora_a\n            lora[name_b] = lora_b\n        return lora\n\n    def initialize_weights(self):\n        state_dict = self.state_dict()\n        for name in state_dict:\n            if \"lora_b\" in name:\n                state_dict[name] = state_dict[name] * 0\n            elif \"lora_a\" in name:\n                state_dict[name] = state_dict[name] * 0.2\n            elif \"proj.weight\" in name:\n                print(name)\n                state_dict[name] = state_dict[name] * 0.2\n        self.load_state_dict(state_dict)\n"
  },
  {
    "path": "diffsynth/models/z_image_text_encoder.py",
    "content": "from transformers import Qwen3Model, Qwen3Config\nimport torch\n\n\nclass ZImageTextEncoder(torch.nn.Module):\n    def __init__(self, model_size=\"4B\"):\n        super().__init__()\n        config_dict = {\n            \"0.6B\": Qwen3Config(**{\n                \"architectures\": [\n                    \"Qwen3ForCausalLM\"\n                ],\n                \"attention_bias\": False,\n                \"attention_dropout\": 0.0,\n                \"bos_token_id\": 151643,\n                \"eos_token_id\": 151645,\n                \"head_dim\": 128,\n                \"hidden_act\": \"silu\",\n                \"hidden_size\": 1024,\n                \"initializer_range\": 0.02,\n                \"intermediate_size\": 3072,\n                \"max_position_embeddings\": 40960,\n                \"max_window_layers\": 28,\n                \"model_type\": \"qwen3\",\n                \"num_attention_heads\": 16,\n                \"num_hidden_layers\": 28,\n                \"num_key_value_heads\": 8,\n                \"rms_norm_eps\": 1e-06,\n                \"rope_scaling\": None,\n                \"rope_theta\": 1000000,\n                \"sliding_window\": None,\n                \"tie_word_embeddings\": True,\n                \"torch_dtype\": \"bfloat16\",\n                \"transformers_version\": \"4.51.0\",\n                \"use_cache\": True,\n                \"use_sliding_window\": False,\n                \"vocab_size\": 151936\n            }),\n            \"4B\": Qwen3Config(**{\n                \"architectures\": [\n                    \"Qwen3ForCausalLM\"\n                ],\n                \"attention_bias\": False,\n                \"attention_dropout\": 0.0,\n                \"bos_token_id\": 151643,\n                \"eos_token_id\": 151645,\n                \"head_dim\": 128,\n                \"hidden_act\": \"silu\",\n                \"hidden_size\": 2560,\n                \"initializer_range\": 0.02,\n                
\"intermediate_size\": 9728,\n                \"max_position_embeddings\": 40960,\n                \"max_window_layers\": 36,\n                \"model_type\": \"qwen3\",\n                \"num_attention_heads\": 32,\n                \"num_hidden_layers\": 36,\n                \"num_key_value_heads\": 8,\n                \"rms_norm_eps\": 1e-06,\n                \"rope_scaling\": None,\n                \"rope_theta\": 1000000,\n                \"sliding_window\": None,\n                \"tie_word_embeddings\": True,\n                \"torch_dtype\": \"bfloat16\",\n                \"transformers_version\": \"4.51.0\",\n                \"use_cache\": True,\n                \"use_sliding_window\": False,\n                \"vocab_size\": 151936\n            }),\n            \"8B\": Qwen3Config(**{\n                \"architectures\": [\n                    \"Qwen3ForCausalLM\"\n                ],\n                \"attention_bias\": False,\n                \"attention_dropout\": 0.0,\n                \"bos_token_id\": 151643,\n                \"dtype\": \"bfloat16\",\n                \"eos_token_id\": 151645,\n                \"head_dim\": 128,\n                \"hidden_act\": \"silu\",\n                \"hidden_size\": 4096,\n                \"initializer_range\": 0.02,\n                \"intermediate_size\": 12288,\n                \"max_position_embeddings\": 40960,\n                \"max_window_layers\": 36,\n                \"model_type\": \"qwen3\",\n                \"num_attention_heads\": 32,\n                \"num_hidden_layers\": 36,\n                \"num_key_value_heads\": 8,\n                \"rms_norm_eps\": 1e-06,\n                \"rope_scaling\": None,\n                \"rope_theta\": 1000000,\n                \"sliding_window\": None,\n                \"tie_word_embeddings\": False,\n                \"transformers_version\": \"4.56.1\",\n                \"use_cache\": True,\n                \"use_sliding_window\": False,\n                \"vocab_size\": 
151936\n            })\n        }\n        config = config_dict[model_size]\n        self.model = Qwen3Model(config)\n    \n    def forward(self, *args, **kwargs):\n        return self.model(*args, **kwargs)\n"
  },
  {
    "path": "diffsynth/pipelines/anima_image.py",
    "content": "import torch, math\nfrom PIL import Image\nfrom typing import Union\nfrom tqdm import tqdm\nfrom einops import rearrange\nimport numpy as np\nfrom math import prod\nfrom transformers import AutoTokenizer\n\nfrom ..core.device.npu_compatible_device import get_device_type\nfrom ..diffusion import FlowMatchScheduler\nfrom ..core import ModelConfig, gradient_checkpoint_forward\nfrom ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput\nfrom ..utils.lora.merge import merge_lora\n\nfrom ..models.anima_dit import AnimaDiT\nfrom ..models.z_image_text_encoder import ZImageTextEncoder\nfrom ..models.wan_video_vae import WanVideoVAE\n\n\nclass AnimaImagePipeline(BasePipeline):\n\n    def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):\n        super().__init__(\n            device=device, torch_dtype=torch_dtype,\n            height_division_factor=16, width_division_factor=16,\n        )\n        self.scheduler = FlowMatchScheduler(\"Z-Image\")\n        self.text_encoder: ZImageTextEncoder = None\n        self.dit: AnimaDiT = None\n        self.vae: WanVideoVAE = None\n        self.tokenizer: AutoTokenizer = None\n        self.tokenizer_t5xxl: AutoTokenizer = None\n        self.in_iteration_models = (\"dit\",)\n        self.units = [\n            AnimaUnit_ShapeChecker(),\n            AnimaUnit_NoiseInitializer(),\n            AnimaUnit_InputImageEmbedder(),\n            AnimaUnit_PromptEmbedder(),\n        ]\n        self.model_fn = model_fn_anima\n    \n    \n    @staticmethod\n    def from_pretrained(\n        torch_dtype: torch.dtype = torch.bfloat16,\n        device: Union[str, torch.device] = get_device_type(),\n        model_configs: list[ModelConfig] = [],\n        tokenizer_config: ModelConfig = ModelConfig(model_id=\"Qwen/Qwen3-0.6B\", origin_file_pattern=\"./\"),\n        tokenizer_t5xxl_config: ModelConfig = ModelConfig(model_id=\"stabilityai/stable-diffusion-3.5-large\", 
origin_file_pattern=\"tokenizer_3/\"),\n        vram_limit: float = None,\n    ):\n        # Initialize pipeline\n        pipe = AnimaImagePipeline(device=device, torch_dtype=torch_dtype)\n        model_pool = pipe.download_and_load_models(model_configs, vram_limit)\n        \n        # Fetch models\n        pipe.text_encoder = model_pool.fetch_model(\"z_image_text_encoder\")\n        pipe.dit = model_pool.fetch_model(\"anima_dit\")\n        pipe.vae = model_pool.fetch_model(\"wan_video_vae\")\n        if tokenizer_config is not None:\n            tokenizer_config.download_if_necessary()\n            pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)\n        if tokenizer_t5xxl_config is not None:\n            tokenizer_t5xxl_config.download_if_necessary()\n            pipe.tokenizer_t5xxl = AutoTokenizer.from_pretrained(tokenizer_t5xxl_config.path)\n        # VRAM Management\n        pipe.vram_management_enabled = pipe.check_vram_management_state()\n        return pipe\n    \n    \n    @torch.no_grad()\n    def __call__(\n        self,\n        # Prompt\n        prompt: str,\n        negative_prompt: str = \"\",\n        cfg_scale: float = 4.0,\n        # Image\n        input_image: Image.Image = None,\n        denoising_strength: float = 1.0,\n        # Shape\n        height: int = 1024,\n        width: int = 1024,\n        # Randomness\n        seed: int = None,\n        rand_device: str = \"cpu\",\n        # Steps\n        num_inference_steps: int = 30,\n        sigma_shift: float = None,\n        # Progress bar\n        progress_bar_cmd = tqdm,\n    ):\n        # Scheduler\n        self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)\n        \n        # Parameters\n        inputs_posi = {\n            \"prompt\": prompt,\n        }\n        inputs_nega = {\n            \"negative_prompt\": negative_prompt,\n        }\n        inputs_shared = {\n            \"cfg_scale\": cfg_scale,\n    
        \"input_image\": input_image, \"denoising_strength\": denoising_strength,\n            \"height\": height, \"width\": width,\n            \"seed\": seed, \"rand_device\": rand_device,\n            \"num_inference_steps\": num_inference_steps,\n        }\n        for unit in self.units:\n            inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)\n\n        # Denoise\n        self.load_models_to_device(self.in_iteration_models)\n        models = {name: getattr(self, name) for name in self.in_iteration_models}\n        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):\n            timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)\n            noise_pred = self.cfg_guided_model_fn(\n                self.model_fn, cfg_scale,\n                inputs_shared, inputs_posi, inputs_nega,\n                **models, timestep=timestep, progress_id=progress_id\n            )\n            inputs_shared[\"latents\"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)\n        \n        # Decode\n        self.load_models_to_device(['vae'])\n        image = self.vae.decode(inputs_shared[\"latents\"].unsqueeze(2), device=self.device).squeeze(2)\n        image = self.vae_output_to_image(image)\n        self.load_models_to_device([])\n\n        return image\n\n\nclass AnimaUnit_ShapeChecker(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            input_params=(\"height\", \"width\"),\n            output_params=(\"height\", \"width\"),\n        )\n\n    def process(self, pipe: AnimaImagePipeline, height, width):\n        height, width = pipe.check_resize_height_width(height, width)\n        return {\"height\": height, \"width\": width}\n\n\n\nclass AnimaUnit_NoiseInitializer(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            input_params=(\"height\", \"width\", 
\"seed\", \"rand_device\"),\n            output_params=(\"noise\",),\n        )\n\n    def process(self, pipe: AnimaImagePipeline, height, width, seed, rand_device):\n        noise = pipe.generate_noise((1, 16, height//8, width//8), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)\n        return {\"noise\": noise}\n\n\n\nclass AnimaUnit_InputImageEmbedder(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            input_params=(\"input_image\", \"noise\"),\n            output_params=(\"latents\", \"input_latents\"),\n            onload_model_names=(\"vae\",)\n        )\n\n    def process(self, pipe: AnimaImagePipeline, input_image, noise):\n        if input_image is None:\n            return {\"latents\": noise, \"input_latents\": None}\n        pipe.load_models_to_device(['vae'])\n        if isinstance(input_image, list):\n            input_latents = []\n            for image in input_image:\n                image = pipe.preprocess_image(image).to(device=pipe.device, dtype=pipe.torch_dtype)\n                input_latents.append(pipe.vae.encode(image))\n            input_latents = torch.concat(input_latents, dim=0)\n        else:\n            image = pipe.preprocess_image(input_image).to(device=pipe.device, dtype=pipe.torch_dtype)\n            input_latents = pipe.vae.encode(image.unsqueeze(2), device=pipe.device).squeeze(2)\n        if pipe.scheduler.training:\n            return {\"latents\": noise, \"input_latents\": input_latents}\n        else:\n            latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])\n            return {\"latents\": latents, \"input_latents\": input_latents}\n\n\nclass AnimaUnit_PromptEmbedder(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            seperate_cfg=True,\n            input_params_posi={\"prompt\": \"prompt\"},\n            input_params_nega={\"prompt\": \"negative_prompt\"},\n            
output_params=(\"prompt_emb\",),\n            onload_model_names=(\"text_encoder\",)\n        )\n\n    def encode_prompt(\n        self,\n        pipe: AnimaImagePipeline,\n        prompt,\n        device = None,\n        max_sequence_length: int = 512,\n    ):\n        if isinstance(prompt, str):\n            prompt = [prompt]\n\n        text_inputs = pipe.tokenizer(\n            prompt,\n            padding=\"max_length\",\n            max_length=max_sequence_length,\n            truncation=True,\n            return_tensors=\"pt\",\n        )\n\n        text_input_ids = text_inputs.input_ids.to(device)\n        prompt_masks = text_inputs.attention_mask.to(device).bool()\n\n        prompt_embeds = pipe.text_encoder(\n            input_ids=text_input_ids,\n            attention_mask=prompt_masks,\n            output_hidden_states=True,\n        ).hidden_states[-1]\n        \n        t5xxl_text_inputs = pipe.tokenizer_t5xxl(\n            prompt,\n            max_length=max_sequence_length,\n            truncation=True,\n            return_tensors=\"pt\",\n        )\n        t5xxl_ids = t5xxl_text_inputs.input_ids.to(device)\n\n        return prompt_embeds.to(pipe.torch_dtype), t5xxl_ids\n\n    def process(self, pipe: AnimaImagePipeline, prompt):\n        pipe.load_models_to_device(self.onload_model_names)\n        prompt_embeds, t5xxl_ids = self.encode_prompt(pipe, prompt, pipe.device)\n        return {\"prompt_emb\": prompt_embeds, \"t5xxl_ids\": t5xxl_ids}\n\n\ndef model_fn_anima(\n    dit: AnimaDiT = None,\n    latents=None,\n    timestep=None,\n    prompt_emb=None,\n    t5xxl_ids=None,\n    use_gradient_checkpointing=False,\n    use_gradient_checkpointing_offload=False,\n    **kwargs\n):\n    latents = latents.unsqueeze(2)\n    timestep = timestep / 1000\n    model_output = dit(\n        x=latents,\n        timesteps=timestep,\n        context=prompt_emb,\n        t5xxl_ids=t5xxl_ids,\n        use_gradient_checkpointing=use_gradient_checkpointing,\n        
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,\n    )\n    model_output = model_output.squeeze(2)\n    return model_output\n"
  },
  {
    "path": "diffsynth/pipelines/flux2_image.py",
    "content": "import torch, math, torchvision\nfrom PIL import Image\nfrom typing import Union\nfrom tqdm import tqdm\nfrom einops import rearrange\nimport numpy as np\nfrom typing import Union, List, Optional, Tuple\n\nfrom ..core.device.npu_compatible_device import get_device_type\nfrom ..diffusion import FlowMatchScheduler\nfrom ..core import ModelConfig, gradient_checkpoint_forward\nfrom ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput\n\nfrom transformers import AutoProcessor, AutoTokenizer\nfrom ..models.flux2_text_encoder import Flux2TextEncoder\nfrom ..models.flux2_dit import Flux2DiT\nfrom ..models.flux2_vae import Flux2VAE\nfrom ..models.z_image_text_encoder import ZImageTextEncoder\n\n\nclass Flux2ImagePipeline(BasePipeline):\n\n    def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):\n        super().__init__(\n            device=device, torch_dtype=torch_dtype,\n            height_division_factor=16, width_division_factor=16,\n        )\n        self.scheduler = FlowMatchScheduler(\"FLUX.2\")\n        self.text_encoder: Flux2TextEncoder = None\n        self.text_encoder_qwen3: ZImageTextEncoder = None\n        self.dit: Flux2DiT = None\n        self.vae: Flux2VAE = None\n        self.tokenizer: AutoProcessor = None\n        self.in_iteration_models = (\"dit\",)\n        self.units = [\n            Flux2Unit_ShapeChecker(),\n            Flux2Unit_PromptEmbedder(),\n            Flux2Unit_Qwen3PromptEmbedder(),\n            Flux2Unit_NoiseInitializer(),\n            Flux2Unit_InputImageEmbedder(),\n            Flux2Unit_EditImageEmbedder(),\n            Flux2Unit_ImageIDs(),\n        ]\n        self.model_fn = model_fn_flux2\n    \n    \n    @staticmethod\n    def from_pretrained(\n        torch_dtype: torch.dtype = torch.bfloat16,\n        device: Union[str, torch.device] = get_device_type(),\n        model_configs: list[ModelConfig] = [],\n        tokenizer_config: ModelConfig = 
ModelConfig(model_id=\"black-forest-labs/FLUX.2-dev\", origin_file_pattern=\"tokenizer/\"),\n        vram_limit: float = None,\n    ):\n        # Initialize pipeline\n        pipe = Flux2ImagePipeline(device=device, torch_dtype=torch_dtype)\n        model_pool = pipe.download_and_load_models(model_configs, vram_limit)\n        \n        # Fetch models\n        pipe.text_encoder = model_pool.fetch_model(\"flux2_text_encoder\")\n        pipe.text_encoder_qwen3 = model_pool.fetch_model(\"z_image_text_encoder\")\n        pipe.dit = model_pool.fetch_model(\"flux2_dit\")\n        pipe.vae = model_pool.fetch_model(\"flux2_vae\")\n        if tokenizer_config is not None:\n            tokenizer_config.download_if_necessary()\n            pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)\n        \n        # VRAM Management\n        pipe.vram_management_enabled = pipe.check_vram_management_state()\n        return pipe\n    \n    \n    @torch.no_grad()\n    def __call__(\n        self,\n        # Prompt\n        prompt: str,\n        negative_prompt: str = \"\",\n        cfg_scale: float = 1.0,\n        embedded_guidance: float = 4.0,\n        # Image\n        input_image: Image.Image = None,\n        denoising_strength: float = 1.0,\n        # Edit\n        edit_image: Union[Image.Image, List[Image.Image]] = None,\n        edit_image_auto_resize: bool = True,\n        # Shape\n        height: int = 1024,\n        width: int = 1024,\n        # Randomness\n        seed: int = None,\n        rand_device: str = \"cpu\",\n        initial_noise: torch.Tensor = None,\n        # Steps\n        num_inference_steps: int = 30,\n        # Progress bar\n        progress_bar_cmd = tqdm,\n    ):\n        self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=height//16*width//16)\n\n        # Parameters\n        inputs_posi = {\n            \"prompt\": prompt,\n        }\n        inputs_nega = {\n            
\"negative_prompt\": negative_prompt,\n        }\n        inputs_shared = {\n            \"cfg_scale\": cfg_scale, \"embedded_guidance\": embedded_guidance,\n            \"input_image\": input_image, \"denoising_strength\": denoising_strength,\n            \"edit_image\": edit_image, \"edit_image_auto_resize\": edit_image_auto_resize,\n            \"height\": height, \"width\": width,\n            \"seed\": seed, \"rand_device\": rand_device, \"initial_noise\": initial_noise,\n            \"num_inference_steps\": num_inference_steps,\n        }\n        for unit in self.units:\n            inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)\n\n        # Denoise\n        self.load_models_to_device(self.in_iteration_models)\n        models = {name: getattr(self, name) for name in self.in_iteration_models}\n        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):\n            timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)\n            noise_pred = self.cfg_guided_model_fn(\n                self.model_fn, cfg_scale,\n                inputs_shared, inputs_posi, inputs_nega,\n                **models, timestep=timestep, progress_id=progress_id\n            )\n            inputs_shared[\"latents\"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)\n        \n        # Decode\n        self.load_models_to_device(['vae'])\n        latents = rearrange(inputs_shared[\"latents\"], \"B (H W) C -> B C H W\", H=inputs_shared[\"height\"]//16, W=inputs_shared[\"width\"]//16)\n        image = self.vae.decode(latents)\n        image = self.vae_output_to_image(image)\n        self.load_models_to_device([])\n\n        return image\n\n\nclass Flux2Unit_ShapeChecker(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            input_params=(\"height\", \"width\"),\n            output_params=(\"height\", 
\"width\"),\n        )\n\n    def process(self, pipe: Flux2ImagePipeline, height, width):\n        height, width = pipe.check_resize_height_width(height, width)\n        return {\"height\": height, \"width\": width}\n\n\nclass Flux2Unit_PromptEmbedder(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            seperate_cfg=True,\n            input_params_posi={\"prompt\": \"prompt\"},\n            input_params_nega={\"prompt\": \"negative_prompt\"},\n            output_params=(\"prompt_emb\", \"prompt_emb_mask\"),\n            onload_model_names=(\"text_encoder\",)\n        )\n        self.system_message = \"You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object attribution and actions without speculation.\"\n\n    def format_text_input(self, prompts: List[str], system_message: str = None):\n        # Remove [IMG] tokens from prompts to avoid Pixtral validation issues\n        # when truncation is enabled. The processor counts [IMG] tokens and fails\n        # if the count changes after truncation.\n        cleaned_txt = [prompt.replace(\"[IMG]\", \"\") for prompt in prompts]\n\n        return [\n            [\n                {\n                    \"role\": \"system\",\n                    \"content\": [{\"type\": \"text\", \"text\": system_message}],\n                },\n                {\"role\": \"user\", \"content\": [{\"type\": \"text\", \"text\": prompt}]},\n            ]\n            for prompt in cleaned_txt\n        ]\n\n    def get_mistral_3_small_prompt_embeds(\n        self,\n        text_encoder,\n        tokenizer,\n        prompt: Union[str, List[str]],\n        dtype: Optional[torch.dtype] = None,\n        device: Optional[torch.device] = None,\n        max_sequence_length: int = 512,\n        # fmt: off\n        system_message: str = \"You are an AI that reasons about image descriptions. 
You give structured responses focusing on object relationships, object attribution and actions without speculation.\",\n        # fmt: on\n        hidden_states_layers: List[int] = (10, 20, 30),\n    ):\n        dtype = text_encoder.dtype if dtype is None else dtype\n        device = text_encoder.device if device is None else device\n\n        prompt = [prompt] if isinstance(prompt, str) else prompt\n\n        # Format input messages\n        messages_batch = self.format_text_input(prompts=prompt, system_message=system_message)\n\n        # Process all messages at once\n        inputs = tokenizer.apply_chat_template(\n            messages_batch,\n            add_generation_prompt=False,\n            tokenize=True,\n            return_dict=True,\n            return_tensors=\"pt\",\n            padding=\"max_length\",\n            truncation=True,\n            max_length=max_sequence_length,\n        )\n\n        # Move to device\n        input_ids = inputs[\"input_ids\"].to(device)\n        attention_mask = inputs[\"attention_mask\"].to(device)\n\n        # Forward pass through the model\n        output = text_encoder(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            output_hidden_states=True,\n            use_cache=False,\n        )\n\n        # Only use outputs from intermediate layers and stack them\n        out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1)\n        out = out.to(dtype=dtype, device=device)\n\n        batch_size, num_channels, seq_len, hidden_dim = out.shape\n        prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)\n\n        return prompt_embeds\n    \n    def prepare_text_ids(\n        self,\n        x: torch.Tensor,  # (B, L, D) or (L, D)\n        t_coord: Optional[torch.Tensor] = None,\n    ):\n        B, L, _ = x.shape\n        out_ids = []\n\n        for i in range(B):\n            t = torch.arange(1) if t_coord is None 
else t_coord[i]\n            h = torch.arange(1)\n            w = torch.arange(1)\n            l = torch.arange(L)\n\n            coords = torch.cartesian_prod(t, h, w, l)\n            out_ids.append(coords)\n\n        return torch.stack(out_ids)\n\n    def encode_prompt(\n        self,\n        text_encoder,\n        tokenizer,\n        prompt: Union[str, List[str]],\n        dtype = None,\n        device: Optional[torch.device] = None,\n        num_images_per_prompt: int = 1,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        max_sequence_length: int = 512,\n        text_encoder_out_layers: Tuple[int] = (10, 20, 30),\n    ):\n        prompt = [prompt] if isinstance(prompt, str) else prompt\n\n        if prompt_embeds is None:\n            prompt_embeds = self.get_mistral_3_small_prompt_embeds(\n                text_encoder=text_encoder,\n                tokenizer=tokenizer,\n                prompt=prompt,\n                dtype=dtype,\n                device=device,\n                max_sequence_length=max_sequence_length,\n                system_message=self.system_message,\n                hidden_states_layers=text_encoder_out_layers,\n            )\n\n        batch_size, seq_len, _ = prompt_embeds.shape\n        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)\n\n        text_ids = self.prepare_text_ids(prompt_embeds)\n        text_ids = text_ids.to(device)\n        return prompt_embeds, text_ids\n\n    def process(self, pipe: Flux2ImagePipeline, prompt):\n        # Skip if Qwen3 text encoder is available (handled by Qwen3PromptEmbedder)\n        if pipe.text_encoder_qwen3 is not None:\n            return {}\n        \n        pipe.load_models_to_device(self.onload_model_names)\n        prompt_embeds, text_ids = self.encode_prompt(\n            pipe.text_encoder, pipe.tokenizer, prompt,\n            dtype=pipe.torch_dtype, 
device=pipe.device,\n        )\n        return {\"prompt_embeds\": prompt_embeds, \"text_ids\": text_ids}\n\n\nclass Flux2Unit_Qwen3PromptEmbedder(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            seperate_cfg=True,\n            input_params_posi={\"prompt\": \"prompt\"},\n            input_params_nega={\"prompt\": \"negative_prompt\"},\n            output_params=(\"prompt_emb\", \"prompt_emb_mask\"),\n            onload_model_names=(\"text_encoder_qwen3\",)\n        )\n        self.hidden_states_layers = (9, 18, 27)  # Qwen3 layers\n\n    def get_qwen3_prompt_embeds(\n        self,\n        text_encoder: ZImageTextEncoder,\n        tokenizer: AutoTokenizer,\n        prompt: Union[str, List[str]],\n        dtype: Optional[torch.dtype] = None,\n        device: Optional[torch.device] = None,\n        max_sequence_length: int = 512,\n    ):\n        dtype = text_encoder.dtype if dtype is None else dtype\n        device = text_encoder.device if device is None else device\n\n        prompt = [prompt] if isinstance(prompt, str) else prompt\n\n        all_input_ids = []\n        all_attention_masks = []\n\n        for single_prompt in prompt:\n            messages = [{\"role\": \"user\", \"content\": single_prompt}]\n            text = tokenizer.apply_chat_template(\n                messages,\n                tokenize=False,\n                add_generation_prompt=True,\n                enable_thinking=False,\n            )\n            inputs = tokenizer(\n                text,\n                return_tensors=\"pt\",\n                padding=\"max_length\",\n                truncation=True,\n                max_length=max_sequence_length,\n            )\n\n            all_input_ids.append(inputs[\"input_ids\"])\n            all_attention_masks.append(inputs[\"attention_mask\"])\n\n        input_ids = torch.cat(all_input_ids, dim=0).to(device)\n        attention_mask = torch.cat(all_attention_masks, dim=0).to(device)\n\n        # Forward pass 
through the model\n        output = text_encoder(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            output_hidden_states=True,\n            use_cache=False,\n        )\n\n        # Only use outputs from intermediate layers and stack them\n        out = torch.stack([output.hidden_states[k] for k in self.hidden_states_layers], dim=1)\n        out = out.to(dtype=dtype, device=device)\n\n        batch_size, num_channels, seq_len, hidden_dim = out.shape\n        prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)\n        return prompt_embeds\n\n    def prepare_text_ids(\n        self,\n        x: torch.Tensor,  # (B, L, D) or (L, D)\n        t_coord: Optional[torch.Tensor] = None,\n    ):\n        B, L, _ = x.shape\n        out_ids = []\n\n        for i in range(B):\n            t = torch.arange(1) if t_coord is None else t_coord[i]\n            h = torch.arange(1)\n            w = torch.arange(1)\n            l = torch.arange(L)\n\n            coords = torch.cartesian_prod(t, h, w, l)\n            out_ids.append(coords)\n\n        return torch.stack(out_ids)\n\n    def encode_prompt(\n        self,\n        text_encoder: ZImageTextEncoder,\n        tokenizer: AutoTokenizer,\n        prompt: Union[str, List[str]],\n        dtype = None,\n        device: Optional[torch.device] = None,\n        num_images_per_prompt: int = 1,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        max_sequence_length: int = 512,\n    ):\n        prompt = [prompt] if isinstance(prompt, str) else prompt\n\n        if prompt_embeds is None:\n            prompt_embeds = self.get_qwen3_prompt_embeds(\n                text_encoder=text_encoder,\n                tokenizer=tokenizer,\n                prompt=prompt,\n                dtype=dtype,\n                device=device,\n                max_sequence_length=max_sequence_length,\n            )\n\n        batch_size, seq_len, _ = prompt_embeds.shape\n 
       prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)\n\n        text_ids = self.prepare_text_ids(prompt_embeds)\n        text_ids = text_ids.to(device)\n        return prompt_embeds, text_ids\n\n    def process(self, pipe: Flux2ImagePipeline, prompt):\n        # Check if Qwen3 text encoder is available\n        if pipe.text_encoder_qwen3 is None:\n            return {}\n        \n        pipe.load_models_to_device(self.onload_model_names)\n        prompt_embeds, text_ids = self.encode_prompt(\n            pipe.text_encoder_qwen3, pipe.tokenizer, prompt,\n            dtype=pipe.torch_dtype, device=pipe.device,\n        )\n        return {\"prompt_embeds\": prompt_embeds, \"text_ids\": text_ids}\n\n\nclass Flux2Unit_NoiseInitializer(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            input_params=(\"height\", \"width\", \"seed\", \"rand_device\", \"initial_noise\"),\n            output_params=(\"noise\",),\n        )\n\n    def process(self, pipe: Flux2ImagePipeline, height, width, seed, rand_device, initial_noise):\n        if initial_noise is not None:\n            noise = initial_noise.clone()\n        else:\n            noise = pipe.generate_noise((1, 128, height//16, width//16), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)\n        noise = noise.reshape(1, 128, height//16 * width//16).permute(0, 2, 1)\n        return {\"noise\": noise}\n\n\nclass Flux2Unit_InputImageEmbedder(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            input_params=(\"input_image\", \"noise\"),\n            output_params=(\"latents\", \"input_latents\"),\n            onload_model_names=(\"vae\",)\n        )\n\n    def process(self, pipe: Flux2ImagePipeline, input_image, noise):\n        if input_image is None:\n            return {\"latents\": noise, \"input_latents\": None}\n        
pipe.load_models_to_device(['vae'])\n        image = pipe.preprocess_image(input_image)\n        input_latents = pipe.vae.encode(image)\n        input_latents = rearrange(input_latents, \"B C H W -> B (H W) C\")\n        if pipe.scheduler.training:\n            return {\"latents\": noise, \"input_latents\": input_latents}\n        else:\n            latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])\n            return {\"latents\": latents, \"input_latents\": input_latents}\n\n\nclass Flux2Unit_EditImageEmbedder(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            input_params=(\"edit_image\", \"edit_image_auto_resize\"),\n            output_params=(\"edit_latents\", \"edit_image_ids\"),\n            onload_model_names=(\"vae\",)\n        )\n\n    def calculate_dimensions(self, target_area, ratio):\n        import math\n        width = math.sqrt(target_area * ratio)\n        height = width / ratio\n        width = round(width / 32) * 32\n        height = round(height / 32) * 32\n        return width, height\n    \n    def crop_and_resize(self, image, target_height, target_width):\n        width, height = image.size\n        scale = max(target_width / width, target_height / height)\n        image = torchvision.transforms.functional.resize(\n            image,\n            (round(height*scale), round(width*scale)),\n            interpolation=torchvision.transforms.InterpolationMode.BILINEAR\n        )\n        image = torchvision.transforms.functional.center_crop(image, (target_height, target_width))\n        return image\n\n    def edit_image_auto_resize(self, edit_image):\n        calculated_width, calculated_height = self.calculate_dimensions(1024 * 1024, edit_image.size[0] / edit_image.size[1])\n        return self.crop_and_resize(edit_image, calculated_height, calculated_width)\n    \n    def process_image_ids(self, image_latents, scale=10):\n        t_coords = [scale + scale * t for t in 
torch.arange(0, len(image_latents))]\n        t_coords = [t.view(-1) for t in t_coords]\n\n        image_latent_ids = []\n        for x, t in zip(image_latents, t_coords):\n            x = x.squeeze(0)\n            _, height, width = x.shape\n\n            x_ids = torch.cartesian_prod(t, torch.arange(height), torch.arange(width), torch.arange(1))\n            image_latent_ids.append(x_ids)\n\n        image_latent_ids = torch.cat(image_latent_ids, dim=0)\n        image_latent_ids = image_latent_ids.unsqueeze(0)\n\n        return image_latent_ids\n\n    def process(self, pipe: Flux2ImagePipeline, edit_image, edit_image_auto_resize):\n        if edit_image is None:\n            return {}\n        pipe.load_models_to_device(self.onload_model_names)\n        if isinstance(edit_image, Image.Image):\n            edit_image = [edit_image]\n        resized_edit_image, edit_latents = [], []\n        for image in edit_image:\n            # Preprocess\n            if edit_image_auto_resize is None or edit_image_auto_resize:\n                image = self.edit_image_auto_resize(image)\n            resized_edit_image.append(image)\n            # Encode\n            image = pipe.preprocess_image(image)\n            latents = pipe.vae.encode(image)\n            edit_latents.append(latents)\n        edit_image_ids = self.process_image_ids(edit_latents).to(pipe.device)\n        edit_latents = torch.concat([rearrange(latents, \"B C H W -> B (H W) C\") for latents in edit_latents], dim=1)\n        return {\"edit_latents\": edit_latents, \"edit_image_ids\": edit_image_ids}\n\n\nclass Flux2Unit_ImageIDs(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            input_params=(\"height\", \"width\"),\n            output_params=(\"image_ids\",),\n        )\n\n    def prepare_latent_ids(self, height, width):\n        t = torch.arange(1)  # [0] - time dimension\n        h = torch.arange(height)\n        w = torch.arange(width)\n        l = torch.arange(1)  # [0] - layer 
dimension\n\n        # Create position IDs: (H*W, 4)\n        latent_ids = torch.cartesian_prod(t, h, w, l)\n\n        # Expand to batch: (B, H*W, 4)\n        latent_ids = latent_ids.unsqueeze(0).expand(1, -1, -1)\n\n        return latent_ids\n\n    def process(self, pipe: Flux2ImagePipeline, height, width):\n        image_ids = self.prepare_latent_ids(height // 16, width // 16).to(pipe.device)\n        return {\"image_ids\": image_ids}\n\n\ndef model_fn_flux2(\n    dit: Flux2DiT,\n    latents=None,\n    timestep=None,\n    embedded_guidance=None,\n    prompt_embeds=None,\n    text_ids=None,\n    image_ids=None,\n    edit_latents=None,\n    edit_image_ids=None,\n    use_gradient_checkpointing=False,\n    use_gradient_checkpointing_offload=False,\n    **kwargs,\n):\n    image_seq_len = latents.shape[1]\n    if edit_latents is not None:\n        image_seq_len = latents.shape[1]\n        latents = torch.concat([latents, edit_latents], dim=1)\n        image_ids = torch.concat([image_ids, edit_image_ids], dim=1)\n    embedded_guidance = torch.tensor([embedded_guidance], device=latents.device)\n    model_output = dit(\n        hidden_states=latents,\n        timestep=timestep / 1000,\n        guidance=embedded_guidance,\n        encoder_hidden_states=prompt_embeds,\n        txt_ids=text_ids,\n        img_ids=image_ids,\n        use_gradient_checkpointing=use_gradient_checkpointing,\n        use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,\n    )\n    model_output = model_output[:, :image_seq_len]\n    return model_output\n"
  },
  {
    "path": "diffsynth/pipelines/flux_image.py",
    "content": "import torch, math\nfrom PIL import Image\nfrom typing import Union\nfrom tqdm import tqdm\nfrom einops import rearrange, repeat\nimport numpy as np\nfrom transformers import CLIPTokenizer, T5TokenizerFast\n\nfrom ..core.device.npu_compatible_device import get_device_type\nfrom ..diffusion import FlowMatchScheduler\nfrom ..core import ModelConfig, gradient_checkpoint_forward, load_state_dict\nfrom ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput\nfrom ..utils.lora.flux import FluxLoRALoader\n\nfrom ..models.flux_dit import FluxDiT\nfrom ..models.flux_text_encoder_clip import FluxTextEncoderClip\nfrom ..models.flux_text_encoder_t5 import FluxTextEncoderT5\nfrom ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder\nfrom ..models.flux_value_control import MultiValueEncoder\nfrom ..models.step1x_text_encoder import Step1xEditEmbedder\nfrom ..core.vram.layers import AutoWrappedLinear\n\nclass MultiControlNet(torch.nn.Module):\n    def __init__(self, models: list[torch.nn.Module]):\n        super().__init__()\n        if not isinstance(models, list):\n            models = [models]\n        self.models = torch.nn.ModuleList(models)\n        \n    def process_single_controlnet(self, controlnet_input: ControlNetInput, conditioning: torch.Tensor, **kwargs):\n        model = self.models[controlnet_input.controlnet_id]\n        res_stack, single_res_stack = model(\n            controlnet_conditioning=conditioning,\n            processor_id=controlnet_input.processor_id,\n            **kwargs\n        )\n        res_stack = [res * controlnet_input.scale for res in res_stack]\n        single_res_stack = [res * controlnet_input.scale for res in single_res_stack]\n        return res_stack, single_res_stack\n\n    def forward(self, conditionings: list[torch.Tensor], controlnet_inputs: list[ControlNetInput], progress_id, num_inference_steps, **kwargs):\n        res_stack, single_res_stack = None, None\n        for controlnet_input, 
conditioning in zip(controlnet_inputs, conditionings):\n            progress = (num_inference_steps - 1 - progress_id) / max(num_inference_steps - 1, 1)\n            if progress > controlnet_input.start or progress < controlnet_input.end:\n                continue\n            res_stack_, single_res_stack_ = self.process_single_controlnet(controlnet_input, conditioning, **kwargs)\n            if res_stack is None:\n                res_stack = res_stack_\n                single_res_stack = single_res_stack_\n            else:\n                res_stack = [i + j for i, j in zip(res_stack, res_stack_)]\n                single_res_stack = [i + j for i, j in zip(single_res_stack, single_res_stack_)]\n        return res_stack, single_res_stack\n\n\nclass FluxImagePipeline(BasePipeline):\n\n    def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):\n        super().__init__(\n            device=device, torch_dtype=torch_dtype,\n            height_division_factor=16, width_division_factor=16,\n        )\n        self.scheduler = FlowMatchScheduler(\"FLUX.1\")\n        self.tokenizer_1: CLIPTokenizer = None\n        self.tokenizer_2: T5TokenizerFast = None\n        self.text_encoder_1: FluxTextEncoderClip = None\n        self.text_encoder_2: FluxTextEncoderT5 = None\n        self.dit: FluxDiT = None\n        self.vae_decoder: FluxVAEDecoder = None\n        self.vae_encoder: FluxVAEEncoder = None\n        self.controlnet = None\n        self.ipadapter = None\n        self.ipadapter_image_encoder = None\n        self.qwenvl = None\n        self.step1x_connector = None\n        self.nexus_gen = None\n        self.nexus_gen_generation_adapter = None\n        self.nexus_gen_editing_adapter = None\n        self.value_controller = None\n        self.infinityou_processor = None\n        self.image_proj_model = None\n        self.lora_patcher = None\n        self.lora_encoder = None\n        self.in_iteration_models = (\"dit\", \"step1x_connector\", 
\"controlnet\", \"lora_patcher\")\n        self.units = [\n            FluxImageUnit_ShapeChecker(),\n            FluxImageUnit_NoiseInitializer(),\n            FluxImageUnit_PromptEmbedder(),\n            FluxImageUnit_InputImageEmbedder(),\n            FluxImageUnit_ImageIDs(),\n            FluxImageUnit_EmbeddedGuidanceEmbedder(),\n            FluxImageUnit_Kontext(),\n            FluxImageUnit_InfiniteYou(),\n            FluxImageUnit_ControlNet(),\n            FluxImageUnit_IPAdapter(),\n            FluxImageUnit_EntityControl(),\n            FluxImageUnit_NexusGen(),\n            FluxImageUnit_TeaCache(),\n            FluxImageUnit_Flex(),\n            FluxImageUnit_Step1x(),\n            FluxImageUnit_ValueControl(),\n            FluxImageUnit_LoRAEncode(),\n        ]\n        self.model_fn = model_fn_flux_image\n        self.lora_loader = FluxLoRALoader\n\n    def enable_lora_merger(self):\n        if not (hasattr(self.dit, \"vram_management_enabled\") and getattr(self.dit, \"vram_management_enabled\")):\n            raise ValueError(\"DiT VRAM management is not enabled.\")\n        if self.lora_patcher is not None:\n            for name, module in self.dit.named_modules():\n                if isinstance(module, AutoWrappedLinear):\n                    merger_name = name.replace(\".\", \"___\")\n                    if merger_name in self.lora_patcher.model_dict:\n                        module.lora_merger = self.lora_patcher.model_dict[merger_name]\n\n    @staticmethod\n    def from_pretrained(\n        torch_dtype: torch.dtype = torch.bfloat16,\n        device: Union[str, torch.device] = get_device_type(),\n        model_configs: list[ModelConfig] = [],\n        tokenizer_1_config: ModelConfig = ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"tokenizer/\"),\n        tokenizer_2_config: ModelConfig = ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"tokenizer_2/\"),\n        nexus_gen_processor_config: 
ModelConfig = ModelConfig(model_id=\"DiffSynth-Studio/Nexus-GenV2\", origin_file_pattern=\"processor/\"),\n        step1x_processor_config: ModelConfig = ModelConfig(model_id=\"Qwen/Qwen2.5-VL-7B-Instruct\", origin_file_pattern=\"\"),\n        vram_limit: float = None,\n    ):\n        # Initialize pipeline\n        pipe = FluxImagePipeline(device=device, torch_dtype=torch_dtype)\n        model_pool = pipe.download_and_load_models(model_configs, vram_limit)\n        \n        # Fetch models\n        pipe.text_encoder_1 = model_pool.fetch_model(\"flux_text_encoder_clip\")\n        pipe.text_encoder_2 = model_pool.fetch_model(\"flux_text_encoder_t5\")\n        pipe.dit = model_pool.fetch_model(\"flux_dit\")\n        pipe.vae_encoder = model_pool.fetch_model(\"flux_vae_encoder\")\n        pipe.vae_decoder = model_pool.fetch_model(\"flux_vae_decoder\")\n        if tokenizer_1_config is not None:\n            tokenizer_1_config.download_if_necessary()\n            pipe.tokenizer_1 = CLIPTokenizer.from_pretrained(tokenizer_1_config.path)\n        if tokenizer_2_config is not None:\n            tokenizer_2_config.download_if_necessary()\n            pipe.tokenizer_2 = T5TokenizerFast.from_pretrained(tokenizer_2_config.path)\n        \n        value_controllers = model_pool.fetch_model(\"flux_value_controller\")\n        if value_controllers is not None:\n            pipe.value_controller = MultiValueEncoder(value_controllers)\n            if hasattr(pipe.value_controller.encoders[0], \"vram_management_enabled\"):\n                pipe.value_controller.vram_management_enabled = pipe.value_controller.encoders[0].vram_management_enabled\n        controlnets = model_pool.fetch_model(\"flux_controlnet\")\n        if controlnets is not None: pipe.controlnet = MultiControlNet(controlnets)\n        pipe.ipadapter = model_pool.fetch_model(\"flux_ipadapter\")\n        pipe.ipadapter_image_encoder = model_pool.fetch_model(\"siglip_vision_model\")\n        qwenvl = 
model_pool.fetch_model(\"qwen_image_text_encoder\")\n        if qwenvl is not None:\n            from transformers import AutoProcessor\n            step1x_processor_config.download_if_necessary()\n            processor = AutoProcessor.from_pretrained(step1x_processor_config.path, min_pixels=256 * 28 * 28, max_pixels=324 * 28 * 28)\n            pipe.qwenvl = Step1xEditEmbedder(qwenvl, processor)\n        pipe.step1x_connector = model_pool.fetch_model(\"step1x_connector\")\n        pipe.image_proj_model = model_pool.fetch_model(\"infiniteyou_image_projector\")\n        if pipe.image_proj_model is not None:\n            pipe.infinityou_processor = InfinitYou(device=device)\n        pipe.lora_patcher = model_pool.fetch_model(\"flux_lora_patcher\")\n        pipe.lora_encoder = model_pool.fetch_model(\"flux_lora_encoder\")\n        pipe.nexus_gen = model_pool.fetch_model(\"nexus_gen_llm\")\n        pipe.nexus_gen_generation_adapter = model_pool.fetch_model(\"nexus_gen_generation_adapter\")\n        pipe.nexus_gen_editing_adapter = model_pool.fetch_model(\"nexus_gen_editing_adapter\")\n        if pipe.nexus_gen is not None:\n            nexus_gen_processor_config.download_if_necessary()\n            pipe.nexus_gen.load_processor(nexus_gen_processor_config.path)\n        \n        # VRAM Management\n        pipe.vram_management_enabled = pipe.check_vram_management_state()\n        return pipe\n    \n    \n    @torch.no_grad()\n    def __call__(\n        self,\n        # Prompt\n        prompt: str,\n        negative_prompt: str = \"\",\n        cfg_scale: float = 1.0,\n        embedded_guidance: float = 3.5,\n        t5_sequence_length: int = 512,\n        # Image\n        input_image: Image.Image = None,\n        denoising_strength: float = 1.0,\n        # Shape\n        height: int = 1024,\n        width: int = 1024,\n        # Randomness\n        seed: int = None,\n        rand_device: str = \"cpu\",\n        # Scheduler\n        sigma_shift: float = None,\n        # 
Steps\n        num_inference_steps: int = 30,\n        # local prompts\n        multidiffusion_prompts=(),\n        multidiffusion_masks=(),\n        multidiffusion_scales=(),\n        # Kontext\n        kontext_images: Union[list[Image.Image], Image.Image] = None,\n        # ControlNet\n        controlnet_inputs: list[ControlNetInput] = None,\n        # IP-Adapter\n        ipadapter_images: Union[list[Image.Image], Image.Image] = None,\n        ipadapter_scale: float = 1.0,\n        # EliGen\n        eligen_entity_prompts: list[str] = None,\n        eligen_entity_masks: list[Image.Image] = None,\n        eligen_enable_on_negative: bool = False,\n        eligen_enable_inpaint: bool = False,\n        # InfiniteYou\n        infinityou_id_image: Image.Image = None,\n        infinityou_guidance: float = 1.0,\n        # Flex\n        flex_inpaint_image: Image.Image = None,\n        flex_inpaint_mask: Image.Image = None,\n        flex_control_image: Image.Image = None,\n        flex_control_strength: float = 0.5,\n        flex_control_stop: float = 0.5,\n        # Value Controller\n        value_controller_inputs: Union[list[float], float] = None,\n        # Step1x\n        step1x_reference_image: Image.Image = None,\n        # NexusGen\n        nexus_gen_reference_image: Image.Image = None,\n        # LoRA Encoder\n        lora_encoder_inputs: Union[list[ModelConfig], ModelConfig, str] = None,\n        lora_encoder_scale: float = 1.0,\n        # TeaCache\n        tea_cache_l1_thresh: float = None,\n        # Tile\n        tiled: bool = False,\n        tile_size: int = 128,\n        tile_stride: int = 64,\n        # Progress bar\n        progress_bar_cmd = tqdm,\n    ):\n        # Scheduler\n        self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)\n        \n        inputs_posi = {\n            \"prompt\": prompt,\n        }\n        inputs_nega = {\n            \"negative_prompt\": negative_prompt,\n        }\n  
      inputs_shared = {\n            \"cfg_scale\": cfg_scale, \"embedded_guidance\": embedded_guidance, \"t5_sequence_length\": t5_sequence_length,\n            \"input_image\": input_image, \"denoising_strength\": denoising_strength,\n            \"height\": height, \"width\": width,\n            \"seed\": seed, \"rand_device\": rand_device,\n            \"sigma_shift\": sigma_shift, \"num_inference_steps\": num_inference_steps,\n            \"multidiffusion_prompts\": multidiffusion_prompts, \"multidiffusion_masks\": multidiffusion_masks, \"multidiffusion_scales\": multidiffusion_scales,\n            \"kontext_images\": kontext_images,\n            \"controlnet_inputs\": controlnet_inputs,\n            \"ipadapter_images\": ipadapter_images, \"ipadapter_scale\": ipadapter_scale,\n            \"eligen_entity_prompts\": eligen_entity_prompts, \"eligen_entity_masks\": eligen_entity_masks, \"eligen_enable_on_negative\": eligen_enable_on_negative, \"eligen_enable_inpaint\": eligen_enable_inpaint,\n            \"infinityou_id_image\": infinityou_id_image, \"infinityou_guidance\": infinityou_guidance,\n            \"flex_inpaint_image\": flex_inpaint_image, \"flex_inpaint_mask\": flex_inpaint_mask, \"flex_control_image\": flex_control_image, \"flex_control_strength\": flex_control_strength, \"flex_control_stop\": flex_control_stop,\n            \"value_controller_inputs\": value_controller_inputs,\n            \"step1x_reference_image\": step1x_reference_image,\n            \"nexus_gen_reference_image\": nexus_gen_reference_image,\n            \"lora_encoder_inputs\": lora_encoder_inputs, \"lora_encoder_scale\": lora_encoder_scale,\n            \"tea_cache_l1_thresh\": tea_cache_l1_thresh,\n            \"tiled\": tiled, \"tile_size\": tile_size, \"tile_stride\": tile_stride,\n            \"progress_bar_cmd\": progress_bar_cmd,\n        }\n        for unit in self.units:\n            inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, 
inputs_posi, inputs_nega)\n\n        # Denoise\n        self.load_models_to_device(self.in_iteration_models)\n        models = {name: getattr(self, name) for name in self.in_iteration_models}\n        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):\n            timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)\n            noise_pred = self.cfg_guided_model_fn(\n                self.model_fn, cfg_scale,\n                inputs_shared, inputs_posi, inputs_nega,\n                **models, timestep=timestep, progress_id=progress_id\n            )\n            inputs_shared[\"latents\"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)\n        \n        # Decode\n        self.load_models_to_device(['vae_decoder'])\n        image = self.vae_decoder(inputs_shared[\"latents\"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)\n        image = self.vae_output_to_image(image)\n        self.load_models_to_device([])\n\n        return image\n\n\nclass FluxImageUnit_ShapeChecker(PipelineUnit):\n    def __init__(self):\n        super().__init__(input_params=(\"height\", \"width\"), output_params=(\"height\", \"width\"))\n\n    def process(self, pipe: FluxImagePipeline, height, width):\n        height, width = pipe.check_resize_height_width(height, width)\n        return {\"height\": height, \"width\": width}\n\n\n\nclass FluxImageUnit_NoiseInitializer(PipelineUnit):\n    def __init__(self):\n        super().__init__(input_params=(\"height\", \"width\", \"seed\", \"rand_device\"), output_params=(\"noise\",))\n\n    def process(self, pipe: FluxImagePipeline, height, width, seed, rand_device):\n        noise = pipe.generate_noise((1, 16, height//8, width//8), seed=seed, rand_device=rand_device)\n        return {\"noise\": noise}\n\n\n\nclass FluxImageUnit_InputImageEmbedder(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n   
         input_params=(\"input_image\", \"noise\", \"tiled\", \"tile_size\", \"tile_stride\"),\n            output_params=(\"latents\", \"input_latents\"),\n            onload_model_names=(\"vae_encoder\",)\n        )\n\n    def process(self, pipe: FluxImagePipeline, input_image, noise, tiled, tile_size, tile_stride):\n        if input_image is None:\n            return {\"latents\": noise, \"input_latents\": None}\n        pipe.load_models_to_device(['vae_encoder'])\n        image = pipe.preprocess_image(input_image).to(device=pipe.device, dtype=pipe.torch_dtype)\n        input_latents = pipe.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)\n        if pipe.scheduler.training:\n            return {\"latents\": noise, \"input_latents\": input_latents}\n        else:\n            latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])\n            return {\"latents\": latents, \"input_latents\": None}\n\n\n\nclass FluxImageUnit_PromptEmbedder(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            seperate_cfg=True,\n            input_params_posi={\"prompt\": \"prompt\", \"positive\": \"positive\"},\n            input_params_nega={\"prompt\": \"negative_prompt\", \"positive\": \"positive\"},\n            input_params=(\"t5_sequence_length\",),\n            output_params=(\"prompt_emb\", \"pooled_prompt_emb\", \"text_ids\"),\n            onload_model_names=(\"text_encoder_1\", \"text_encoder_2\")\n        )\n    \n    def encode_prompt_using_clip(self, prompt, text_encoder, tokenizer, max_length, device):\n        input_ids = tokenizer(\n            prompt,\n            return_tensors=\"pt\",\n            padding=\"max_length\",\n            max_length=max_length,\n            truncation=True\n        ).input_ids.to(device)\n        pooled_prompt_emb, _ = text_encoder(input_ids)\n        return pooled_prompt_emb\n    \n    def encode_prompt_using_t5(self, prompt, text_encoder, 
tokenizer, max_length, device):\n        input_ids = tokenizer(\n            prompt,\n            return_tensors=\"pt\",\n            padding=\"max_length\",\n            max_length=max_length,\n            truncation=True,\n        ).input_ids.to(device)\n        prompt_emb = text_encoder(input_ids)\n        return prompt_emb\n\n    def encode_prompt(\n        self,\n        tokenizer_1,\n        tokenizer_2,\n        text_encoder_1,\n        text_encoder_2,\n        prompt,\n        positive=True,\n        device=get_device_type(),\n        t5_sequence_length=512,\n    ):\n        pooled_prompt_emb = self.encode_prompt_using_clip(prompt, text_encoder_1, tokenizer_1, 77, device)\n        prompt_emb = self.encode_prompt_using_t5(prompt, text_encoder_2, tokenizer_2, t5_sequence_length, device)\n        text_ids = torch.zeros(prompt_emb.shape[0], prompt_emb.shape[1], 3).to(device=device, dtype=prompt_emb.dtype)\n        return prompt_emb, pooled_prompt_emb, text_ids\n\n    def process(self, pipe: FluxImagePipeline, prompt, t5_sequence_length, positive) -> dict:\n        if pipe.text_encoder_1 is not None and pipe.text_encoder_2 is not None:\n            prompt_emb, pooled_prompt_emb, text_ids = self.encode_prompt(\n                tokenizer_1=pipe.tokenizer_1, tokenizer_2=pipe.tokenizer_2,\n                text_encoder_1=pipe.text_encoder_1, text_encoder_2=pipe.text_encoder_2,\n                prompt=prompt, device=pipe.device, positive=positive, t5_sequence_length=t5_sequence_length,\n            )\n            return {\"prompt_emb\": prompt_emb, \"pooled_prompt_emb\": pooled_prompt_emb, \"text_ids\": text_ids}\n        else:\n            return {}\n\n\nclass FluxImageUnit_ImageIDs(PipelineUnit):\n    def __init__(self):\n        super().__init__(input_params=(\"latents\",), output_params=(\"image_ids\",))\n\n    def process(self, pipe: FluxImagePipeline, latents):\n        latent_image_ids = pipe.dit.prepare_image_ids(latents)\n        return {\"image_ids\": 
latent_image_ids}\n\n\n\nclass FluxImageUnit_EmbeddedGuidanceEmbedder(PipelineUnit):\n    def __init__(self):\n        super().__init__(input_params=(\"embedded_guidance\", \"latents\"), output_params=(\"guidance\",))\n\n    def process(self, pipe: FluxImagePipeline, embedded_guidance, latents):\n        guidance = torch.Tensor([embedded_guidance] * latents.shape[0]).to(device=latents.device, dtype=latents.dtype)\n        return {\"guidance\": guidance}\n\n\n\nclass FluxImageUnit_Kontext(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            input_params=(\"kontext_images\", \"tiled\", \"tile_size\", \"tile_stride\"),\n            output_params=(\"kontext_latents\", \"kontext_image_ids\"),\n            onload_model_names=(\"vae_encoder\",)\n        )\n\n    def process(self, pipe: FluxImagePipeline, kontext_images, tiled, tile_size, tile_stride):\n        if kontext_images is None:\n            return {}\n        if not isinstance(kontext_images, list):\n            kontext_images = [kontext_images]\n            \n        kontext_latents = []\n        kontext_image_ids = []\n        for kontext_image in kontext_images:\n            kontext_image = pipe.preprocess_image(kontext_image)\n            kontext_latent = pipe.vae_encoder(kontext_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)\n            image_ids = pipe.dit.prepare_image_ids(kontext_latent)\n            image_ids[..., 0] = 1\n            kontext_image_ids.append(image_ids)\n            kontext_latent = pipe.dit.patchify(kontext_latent)\n            kontext_latents.append(kontext_latent)\n        kontext_latents = torch.concat(kontext_latents, dim=1)\n        kontext_image_ids = torch.concat(kontext_image_ids, dim=-2)\n        return {\"kontext_latents\": kontext_latents, \"kontext_image_ids\": kontext_image_ids}\n\n\n\nclass FluxImageUnit_ControlNet(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            
input_params=(\"controlnet_inputs\", \"tiled\", \"tile_size\", \"tile_stride\"),\n            output_params=(\"controlnet_conditionings\",),\n            onload_model_names=(\"vae_encoder\",)\n        )\n        \n    def apply_controlnet_mask_on_latents(self, pipe, latents, mask):\n        mask = (pipe.preprocess_image(mask) + 1) / 2\n        mask = mask.mean(dim=1, keepdim=True)\n        mask = 1 - torch.nn.functional.interpolate(mask, size=latents.shape[-2:])\n        latents = torch.concat([latents, mask], dim=1)\n        return latents\n        \n    def apply_controlnet_mask_on_image(self, pipe, image, mask):\n        mask = mask.resize(image.size)\n        mask = pipe.preprocess_image(mask).mean(dim=[0, 1]).cpu()\n        image = np.array(image)\n        image[mask > 0] = 0\n        image = Image.fromarray(image)\n        return image\n\n    def process(self, pipe: FluxImagePipeline, controlnet_inputs: list[ControlNetInput], tiled, tile_size, tile_stride):\n        if controlnet_inputs is None:\n            return {}\n        pipe.load_models_to_device(['vae_encoder'])\n        conditionings = []\n        for controlnet_input in controlnet_inputs:\n            image = controlnet_input.image\n            if controlnet_input.inpaint_mask is not None:\n                image = self.apply_controlnet_mask_on_image(pipe, image, controlnet_input.inpaint_mask)\n\n            image = pipe.preprocess_image(image).to(device=pipe.device, dtype=pipe.torch_dtype)\n            image = pipe.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)\n            \n            if controlnet_input.inpaint_mask is not None:\n                image = self.apply_controlnet_mask_on_latents(pipe, image, controlnet_input.inpaint_mask)\n            conditionings.append(image)\n        return {\"controlnet_conditionings\": conditionings}\n\n\n\nclass FluxImageUnit_IPAdapter(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            
take_over=True,\n            input_params=(\"ipadapter_images\", \"ipadapter_scale\"),\n            output_params=(\"ipadapter_kwargs_list\",),\n            onload_model_names=(\"ipadapter_image_encoder\", \"ipadapter\")\n        )\n\n    def process(self, pipe: FluxImagePipeline, inputs_shared, inputs_posi, inputs_nega):\n        ipadapter_images, ipadapter_scale = inputs_shared.get(\"ipadapter_images\", None), inputs_shared.get(\"ipadapter_scale\", 1.0)\n        if ipadapter_images is None:\n            return inputs_shared, inputs_posi, inputs_nega\n        if not isinstance(ipadapter_images, list):\n            ipadapter_images = [ipadapter_images]\n\n        pipe.load_models_to_device(self.onload_model_names)\n        images = [image.convert(\"RGB\").resize((384, 384), resample=3) for image in ipadapter_images]\n        images = [pipe.preprocess_image(image).to(device=pipe.device, dtype=pipe.torch_dtype) for image in images]\n        ipadapter_images = torch.cat(images, dim=0)\n        ipadapter_image_encoding = pipe.ipadapter_image_encoder(ipadapter_images).pooler_output\n\n        inputs_posi.update({\"ipadapter_kwargs_list\": pipe.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)})\n        if inputs_shared.get(\"cfg_scale\", 1.0) != 1.0:\n            inputs_nega.update({\"ipadapter_kwargs_list\": pipe.ipadapter(torch.zeros_like(ipadapter_image_encoding))})\n        return inputs_shared, inputs_posi, inputs_nega\n\n\n\nclass FluxImageUnit_EntityControl(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            take_over=True,\n            input_params=(\"eligen_entity_prompts\", \"eligen_entity_masks\", \"eligen_enable_on_negative\", \"width\", \"height\", \"t5_sequence_length\", \"cfg_scale\"),\n            output_params=(\"entity_prompt_emb\", \"entity_masks\"),\n            onload_model_names=(\"text_encoder_1\", \"text_encoder_2\")\n        )\n        \n    def encode_prompt_using_clip(self, prompt, text_encoder, tokenizer, 
max_length, device):\n        input_ids = tokenizer(\n            prompt,\n            return_tensors=\"pt\",\n            padding=\"max_length\",\n            max_length=max_length,\n            truncation=True\n        ).input_ids.to(device)\n        pooled_prompt_emb, _ = text_encoder(input_ids)\n        return pooled_prompt_emb\n    \n    def encode_prompt_using_t5(self, prompt, text_encoder, tokenizer, max_length, device):\n        input_ids = tokenizer(\n            prompt,\n            return_tensors=\"pt\",\n            padding=\"max_length\",\n            max_length=max_length,\n            truncation=True,\n        ).input_ids.to(device)\n        prompt_emb = text_encoder(input_ids)\n        return prompt_emb\n\n    def encode_prompt(\n        self,\n        tokenizer_1,\n        tokenizer_2,\n        text_encoder_1,\n        text_encoder_2,\n        prompt,\n        positive=True,\n        device=get_device_type(),\n        t5_sequence_length=512,\n    ):\n        pooled_prompt_emb = self.encode_prompt_using_clip(prompt, text_encoder_1, tokenizer_1, 77, device)\n        prompt_emb = self.encode_prompt_using_t5(prompt, text_encoder_2, tokenizer_2, t5_sequence_length, device)\n        text_ids = torch.zeros(prompt_emb.shape[0], prompt_emb.shape[1], 3).to(device=device, dtype=prompt_emb.dtype)\n        return prompt_emb, pooled_prompt_emb, text_ids\n\n    def preprocess_masks(self, pipe, masks, height, width, dim):\n        out_masks = []\n        for mask in masks:\n            mask = pipe.preprocess_image(mask.resize((width, height), resample=Image.NEAREST)).mean(dim=1, keepdim=True) > 0\n            mask = mask.repeat(1, dim, 1, 1).to(device=pipe.device, dtype=pipe.torch_dtype)\n            out_masks.append(mask)\n        return out_masks\n\n    def prepare_entity_inputs(self, pipe, entity_prompts, entity_masks, width, height, t5_sequence_length=512):\n        entity_masks = self.preprocess_masks(pipe, entity_masks, height//8, width//8, 1)\n        
entity_masks = torch.cat(entity_masks, dim=0).unsqueeze(0) # b, n_mask, c, h, w\n\n        prompt_emb, _, _ = self.encode_prompt(\n            tokenizer_1=pipe.tokenizer_1, tokenizer_2=pipe.tokenizer_2,\n            text_encoder_1=pipe.text_encoder_1, text_encoder_2=pipe.text_encoder_2,\n            prompt=entity_prompts, device=pipe.device, t5_sequence_length=t5_sequence_length,\n        )\n        return prompt_emb.unsqueeze(0), entity_masks\n\n    def prepare_eligen(self, pipe, prompt_emb_nega, eligen_entity_prompts, eligen_entity_masks, width, height, t5_sequence_length, enable_eligen_on_negative, cfg_scale):\n        entity_prompt_emb_posi, entity_masks_posi = self.prepare_entity_inputs(pipe, eligen_entity_prompts, eligen_entity_masks, width, height, t5_sequence_length)\n        if enable_eligen_on_negative and cfg_scale != 1.0:\n            entity_prompt_emb_nega = prompt_emb_nega['prompt_emb'].unsqueeze(1).repeat(1, entity_masks_posi.shape[1], 1, 1)\n            entity_masks_nega = entity_masks_posi\n        else:\n            entity_prompt_emb_nega, entity_masks_nega = None, None\n        eligen_kwargs_posi = {\"entity_prompt_emb\": entity_prompt_emb_posi, \"entity_masks\": entity_masks_posi}\n        eligen_kwargs_nega = {\"entity_prompt_emb\": entity_prompt_emb_nega, \"entity_masks\": entity_masks_nega}\n        return eligen_kwargs_posi, eligen_kwargs_nega\n\n    def process(self, pipe: FluxImagePipeline, inputs_shared, inputs_posi, inputs_nega):\n        eligen_entity_prompts, eligen_entity_masks = inputs_shared.get(\"eligen_entity_prompts\", None), inputs_shared.get(\"eligen_entity_masks\", None)\n        if eligen_entity_prompts is None or eligen_entity_masks is None:\n            return inputs_shared, inputs_posi, inputs_nega\n        pipe.load_models_to_device(self.onload_model_names)\n        eligen_enable_on_negative = inputs_shared.get(\"eligen_enable_on_negative\", False)\n        eligen_kwargs_posi, eligen_kwargs_nega = 
self.prepare_eligen(pipe, inputs_nega,\n            eligen_entity_prompts, eligen_entity_masks, inputs_shared[\"width\"], inputs_shared[\"height\"], \n            inputs_shared[\"t5_sequence_length\"], eligen_enable_on_negative, inputs_shared[\"cfg_scale\"])\n        inputs_posi.update(eligen_kwargs_posi)\n        if inputs_shared.get(\"cfg_scale\", 1.0) != 1.0:\n            inputs_nega.update(eligen_kwargs_nega)\n        return inputs_shared, inputs_posi, inputs_nega\n\n\nclass FluxImageUnit_NexusGen(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            take_over=True,\n            input_params=(\"nexus_gen_reference_image\", \"prompt\", \"latents\"),\n            output_params=(\"prompt_emb\", \"text_ids\"),\n            onload_model_names=(\"nexus_gen\", \"nexus_gen_generation_adapter\", \"nexus_gen_editing_adapter\"),\n        )\n\n    def process(self, pipe: FluxImagePipeline, inputs_shared, inputs_posi, inputs_nega):\n        if pipe.nexus_gen is None:\n            return inputs_shared, inputs_posi, inputs_nega\n        pipe.load_models_to_device(self.onload_model_names)\n        if inputs_shared.get(\"nexus_gen_reference_image\", None) is None:\n            assert pipe.nexus_gen_generation_adapter is not None, \"NexusGen requires a generation adapter to be set.\"\n            embed = pipe.nexus_gen(inputs_posi[\"prompt\"])[0].unsqueeze(0)\n            inputs_posi[\"prompt_emb\"] = pipe.nexus_gen_generation_adapter(embed)\n            inputs_posi['text_ids'] = torch.zeros(embed.shape[0], embed.shape[1], 3).to(device=pipe.device, dtype=pipe.torch_dtype)\n        else:\n            assert pipe.nexus_gen_editing_adapter is not None, \"NexusGen requires an editing adapter to be set.\"\n            embed, ref_embed, grids = pipe.nexus_gen(inputs_posi[\"prompt\"], inputs_shared[\"nexus_gen_reference_image\"])\n            embeds_grid = grids[1:2].to(device=pipe.device, dtype=torch.long)\n            ref_embeds_grid = 
grids[0:1].to(device=pipe.device, dtype=torch.long)\n\n            inputs_posi[\"prompt_emb\"] = pipe.nexus_gen_editing_adapter(embed.unsqueeze(0), embeds_grid, ref_embed.unsqueeze(0), ref_embeds_grid)\n            inputs_posi[\"text_ids\"] = self.get_editing_text_ids(\n                inputs_shared[\"latents\"],\n                embeds_grid[0][1].item(), embeds_grid[0][2].item(),\n                ref_embeds_grid[0][1].item(), ref_embeds_grid[0][2].item(),\n                )\n        return inputs_shared, inputs_posi, inputs_nega\n\n\n    def get_editing_text_ids(self, latents, target_embed_height, target_embed_width, ref_embed_height, ref_embed_width):\n        # prepare text ids for target and reference embeddings\n        batch_size, height, width = latents.shape[0], target_embed_height, target_embed_width\n        embed_ids = torch.zeros(height // 2, width // 2, 3)\n        scale_factor_height, scale_factor_width = latents.shape[-2] / height, latents.shape[-1] / width\n        embed_ids[..., 1] = embed_ids[..., 1] + torch.arange(height // 2)[:, None] * scale_factor_height\n        embed_ids[..., 2] = embed_ids[..., 2] + torch.arange(width // 2)[None, :] * scale_factor_width\n        embed_ids = embed_ids[None, :].repeat(batch_size, 1, 1, 1).reshape(batch_size, height // 2 * width // 2, 3)\n        embed_text_ids = embed_ids.to(device=latents.device, dtype=latents.dtype)\n\n        batch_size, height, width = latents.shape[0], ref_embed_height, ref_embed_width\n        ref_embed_ids = torch.zeros(height // 2, width // 2, 3)\n        scale_factor_height, scale_factor_width = latents.shape[-2] / height, latents.shape[-1] / width\n        ref_embed_ids[..., 0] = ref_embed_ids[..., 0] + 1.0\n        ref_embed_ids[..., 1] = ref_embed_ids[..., 1] + torch.arange(height // 2)[:, None] * scale_factor_height\n        ref_embed_ids[..., 2] = ref_embed_ids[..., 2] + torch.arange(width // 2)[None, :] * scale_factor_width\n        ref_embed_ids = ref_embed_ids[None, 
:].repeat(batch_size, 1, 1, 1).reshape(batch_size, height // 2 * width // 2, 3)\n        ref_embed_text_ids = ref_embed_ids.to(device=latents.device, dtype=latents.dtype)\n\n        text_ids = torch.cat([embed_text_ids, ref_embed_text_ids], dim=1)\n        return text_ids\n\n\nclass FluxImageUnit_Step1x(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            take_over=True,\n            input_params=(\"step1x_reference_image\", \"prompt\", \"negative_prompt\"),\n            output_params=(\"step1x_llm_embedding\", \"step1x_mask\", \"step1x_reference_latents\"),\n            onload_model_names=(\"qwenvl\",\"vae_encoder\")\n        )\n    \n    def process(self, pipe: FluxImagePipeline, inputs_shared: dict, inputs_posi: dict, inputs_nega: dict):\n        image = inputs_shared.get(\"step1x_reference_image\",None)\n        if image is None:\n            return inputs_shared, inputs_posi, inputs_nega\n        else:\n            pipe.load_models_to_device(self.onload_model_names)\n            prompt = inputs_posi[\"prompt\"]\n            nega_prompt = inputs_nega[\"negative_prompt\"]\n            captions = [prompt, nega_prompt]\n            ref_images = [image, image]\n            embs, masks = pipe.qwenvl(captions, ref_images)\n            image = pipe.preprocess_image(image).to(device=pipe.device, dtype=pipe.torch_dtype)\n            image = pipe.vae_encoder(image)\n            inputs_posi.update({\"step1x_llm_embedding\": embs[0:1], \"step1x_mask\": masks[0:1], \"step1x_reference_latents\": image})\n            if inputs_shared.get(\"cfg_scale\", 1) != 1:\n                inputs_nega.update({\"step1x_llm_embedding\": embs[1:2], \"step1x_mask\": masks[1:2], \"step1x_reference_latents\": image})\n            return inputs_shared, inputs_posi, inputs_nega\n\n            \nclass FluxImageUnit_TeaCache(PipelineUnit):\n    def __init__(self):\n        super().__init__(input_params=(\"num_inference_steps\",\"tea_cache_l1_thresh\"), 
output_params=(\"tea_cache\",))\n    \n    def process(self, pipe: FluxImagePipeline, num_inference_steps, tea_cache_l1_thresh):\n        if tea_cache_l1_thresh is None:\n            return {}\n        else:\n            return {\"tea_cache\": TeaCache(num_inference_steps=num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh)}\n\nclass FluxImageUnit_Flex(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            input_params=(\"latents\", \"flex_inpaint_image\", \"flex_inpaint_mask\", \"flex_control_image\", \"flex_control_strength\", \"flex_control_stop\", \"tiled\", \"tile_size\", \"tile_stride\"),\n            output_params=(\"flex_condition\", \"flex_uncondition\", \"flex_control_stop_timestep\"),\n            onload_model_names=(\"vae_encoder\",)\n        )\n\n    def process(self, pipe: FluxImagePipeline, latents, flex_inpaint_image, flex_inpaint_mask, flex_control_image, flex_control_strength, flex_control_stop, tiled, tile_size, tile_stride):\n        if pipe.dit.input_dim == 196:\n            if flex_control_stop is None:\n                flex_control_stop = 1\n            pipe.load_models_to_device(self.onload_model_names)\n            if flex_inpaint_image is None:\n                flex_inpaint_image = torch.zeros_like(latents)\n            else:\n                flex_inpaint_image = pipe.preprocess_image(flex_inpaint_image).to(device=pipe.device, dtype=pipe.torch_dtype)\n                flex_inpaint_image = pipe.vae_encoder(flex_inpaint_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)\n            if flex_inpaint_mask is None:\n                flex_inpaint_mask = torch.ones_like(latents)[:, 0:1, :, :]\n            else:\n                flex_inpaint_mask = flex_inpaint_mask.resize((latents.shape[3], latents.shape[2]))\n                flex_inpaint_mask = pipe.preprocess_image(flex_inpaint_mask).to(device=pipe.device, dtype=pipe.torch_dtype)\n                flex_inpaint_mask = (flex_inpaint_mask[:, 0:1, :, :] + 1) 
/ 2\n            flex_inpaint_image = flex_inpaint_image * (1 - flex_inpaint_mask)\n            if flex_control_image is None:\n                flex_control_image = torch.zeros_like(latents)\n            else:\n                flex_control_image = pipe.preprocess_image(flex_control_image).to(device=pipe.device, dtype=pipe.torch_dtype)\n                flex_control_image = pipe.vae_encoder(flex_control_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) * flex_control_strength\n            flex_condition = torch.concat([flex_inpaint_image, flex_inpaint_mask, flex_control_image], dim=1)\n            flex_uncondition = torch.concat([flex_inpaint_image, flex_inpaint_mask, torch.zeros_like(flex_control_image)], dim=1)\n            flex_control_stop_timestep = pipe.scheduler.timesteps[int(flex_control_stop * (len(pipe.scheduler.timesteps) - 1))]\n            return {\"flex_condition\": flex_condition, \"flex_uncondition\": flex_uncondition, \"flex_control_stop_timestep\": flex_control_stop_timestep}\n        else:\n            return {}\n\n\n\nclass FluxImageUnit_InfiniteYou(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            input_params=(\"infinityou_id_image\", \"infinityou_guidance\"),\n            output_params=(\"id_emb\", \"infinityou_guidance\"),\n            onload_model_names=(\"infinityou_processor\",)\n        )\n\n    def process(self, pipe: FluxImagePipeline, infinityou_id_image, infinityou_guidance):\n        pipe.load_models_to_device(\"infinityou_processor\")\n        if infinityou_id_image is not None:\n            return pipe.infinityou_processor.prepare_infinite_you(pipe.image_proj_model, infinityou_id_image, infinityou_guidance, pipe.device)\n        else:\n            return {}\n\n\n\nclass FluxImageUnit_ValueControl(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            seperate_cfg=True,\n            input_params_posi={\"prompt_emb\": \"prompt_emb\", \"text_ids\": \"text_ids\"},\n  
          input_params_nega={\"prompt_emb\": \"prompt_emb\", \"text_ids\": \"text_ids\"},\n            input_params=(\"value_controller_inputs\",),\n            output_params=(\"prompt_emb\", \"text_ids\"),\n            onload_model_names=(\"value_controller\",)\n        )\n        \n    def add_to_text_embedding(self, prompt_emb, text_ids, value_emb):\n        prompt_emb = torch.concat([prompt_emb, value_emb], dim=1)\n        extra_text_ids = torch.zeros((value_emb.shape[0], value_emb.shape[1], 3), device=value_emb.device, dtype=value_emb.dtype)\n        text_ids = torch.concat([text_ids, extra_text_ids], dim=1)\n        return prompt_emb, text_ids\n\n    def process(self, pipe: FluxImagePipeline, prompt_emb, text_ids, value_controller_inputs):\n        if value_controller_inputs is None:\n            return {}\n        if not isinstance(value_controller_inputs, list):\n            value_controller_inputs = [value_controller_inputs]\n        value_controller_inputs = torch.tensor(value_controller_inputs).to(dtype=pipe.torch_dtype, device=pipe.device)\n        pipe.load_models_to_device([\"value_controller\"])\n        value_emb = pipe.value_controller(value_controller_inputs, pipe.torch_dtype)\n        value_emb = value_emb.unsqueeze(0)\n        prompt_emb, text_ids = self.add_to_text_embedding(prompt_emb, text_ids, value_emb)\n        return {\"prompt_emb\": prompt_emb, \"text_ids\": text_ids}\n\n\n\nclass InfinitYou(torch.nn.Module):\n    def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):\n        super().__init__()\n        from facexlib.recognition import init_recognition_model\n        from insightface.app import FaceAnalysis\n        self.device = device\n        self.torch_dtype = torch_dtype\n        insightface_root_path = 'models/ByteDance/InfiniteYou/supports/insightface'\n        self.app_640 = FaceAnalysis(name='antelopev2', root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])\n        
self.app_640.prepare(ctx_id=0, det_size=(640, 640))\n        self.app_320 = FaceAnalysis(name='antelopev2', root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])\n        self.app_320.prepare(ctx_id=0, det_size=(320, 320))\n        self.app_160 = FaceAnalysis(name='antelopev2', root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])\n        self.app_160.prepare(ctx_id=0, det_size=(160, 160))\n        self.arcface_model = init_recognition_model('arcface', device=self.device).to(torch_dtype)\n\n    def _detect_face(self, id_image_cv2):\n        face_info = self.app_640.get(id_image_cv2)\n        if len(face_info) > 0:\n            return face_info\n        face_info = self.app_320.get(id_image_cv2)\n        if len(face_info) > 0:\n            return face_info\n        face_info = self.app_160.get(id_image_cv2)\n        return face_info\n\n    def extract_arcface_bgr_embedding(self, in_image, landmark, device):\n        from insightface.utils import face_align\n        arc_face_image = face_align.norm_crop(in_image, landmark=np.array(landmark), image_size=112)\n        arc_face_image = torch.from_numpy(arc_face_image).unsqueeze(0).permute(0, 3, 1, 2) / 255.\n        arc_face_image = 2 * arc_face_image - 1\n        arc_face_image = arc_face_image.contiguous().to(device=device, dtype=self.torch_dtype)\n        face_emb = self.arcface_model(arc_face_image)[0] # [512], normalized\n        return face_emb\n\n    def prepare_infinite_you(self, model, id_image, infinityou_guidance, device):\n        import cv2\n        if id_image is None:\n            return {'id_emb': None}\n        id_image_cv2 = cv2.cvtColor(np.array(id_image), cv2.COLOR_RGB2BGR)\n        face_info = self._detect_face(id_image_cv2)\n        if len(face_info) == 0:\n            raise ValueError('No face detected in the input ID image')\n        landmark = sorted(face_info, key=lambda 
x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1]['kps'] # only use the maximum face\n        id_emb = self.extract_arcface_bgr_embedding(id_image_cv2, landmark, device)\n        id_emb = model(id_emb.unsqueeze(0).reshape([1, -1, 512]).to(dtype=self.torch_dtype))\n        infinityou_guidance = torch.Tensor([infinityou_guidance]).to(device=device, dtype=self.torch_dtype)\n        return {'id_emb': id_emb, 'infinityou_guidance': infinityou_guidance}\n\n\n\nclass FluxImageUnit_LoRAEncode(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            take_over=True,\n            input_params=(\"lora_encoder_inputs\", \"lora_encoder_scale\"),\n            output_params=(\"prompt_emb\", \"text_ids\"),\n            onload_model_names=(\"lora_encoder\",)\n        )\n        \n    def parse_lora_encoder_inputs(self, lora_encoder_inputs):\n        if not isinstance(lora_encoder_inputs, list):\n            lora_encoder_inputs = [lora_encoder_inputs]\n        lora_configs = []\n        for lora_encoder_input in lora_encoder_inputs:\n            if isinstance(lora_encoder_input, str):\n                lora_encoder_input = ModelConfig(path=lora_encoder_input)\n            lora_encoder_input.download_if_necessary()\n            lora_configs.append(lora_encoder_input)\n        return lora_configs\n        \n    def load_lora(self, lora_config, dtype, device):\n        loader = FluxLoRALoader(torch_dtype=dtype, device=device)\n        lora = load_state_dict(lora_config.path, torch_dtype=dtype, device=device)\n        lora = loader.convert_state_dict(lora)\n        return lora\n    \n    def lora_embedding(self, pipe, lora_encoder_inputs):\n        lora_emb = []\n        for lora_config in self.parse_lora_encoder_inputs(lora_encoder_inputs):\n            lora = self.load_lora(lora_config, pipe.torch_dtype, pipe.device)\n            lora_emb.append(pipe.lora_encoder(lora))\n        lora_emb = torch.concat(lora_emb, dim=1)\n        return lora_emb\n    \n 
   def add_to_text_embedding(self, prompt_emb, text_ids, lora_emb):\n        prompt_emb = torch.concat([prompt_emb, lora_emb], dim=1)\n        extra_text_ids = torch.zeros((lora_emb.shape[0], lora_emb.shape[1], 3), device=lora_emb.device, dtype=lora_emb.dtype)\n        text_ids = torch.concat([text_ids, extra_text_ids], dim=1)\n        return prompt_emb, text_ids\n\n    def process(self, pipe: FluxImagePipeline, inputs_shared, inputs_posi, inputs_nega):\n        if inputs_shared.get(\"lora_encoder_inputs\", None) is None:\n            return inputs_shared, inputs_posi, inputs_nega\n        \n        # Encode\n        pipe.load_models_to_device([\"lora_encoder\"])\n        lora_encoder_inputs = inputs_shared[\"lora_encoder_inputs\"]\n        lora_emb = self.lora_embedding(pipe, lora_encoder_inputs)\n        \n        # Scale\n        lora_encoder_scale = inputs_shared.get(\"lora_encoder_scale\", None)\n        if lora_encoder_scale is not None:\n            lora_emb = lora_emb * lora_encoder_scale\n        \n        # Add to prompt embedding\n        inputs_posi[\"prompt_emb\"], inputs_posi[\"text_ids\"] = self.add_to_text_embedding(\n            inputs_posi[\"prompt_emb\"], inputs_posi[\"text_ids\"], lora_emb)\n        return inputs_shared, inputs_posi, inputs_nega\n\n\n\nclass TeaCache:\n    def __init__(self, num_inference_steps, rel_l1_thresh):\n        self.num_inference_steps = num_inference_steps\n        self.step = 0\n        self.accumulated_rel_l1_distance = 0\n        self.previous_modulated_input = None\n        self.rel_l1_thresh = rel_l1_thresh\n        self.previous_residual = None\n        self.previous_hidden_states = None\n\n    def check(self, dit: FluxDiT, hidden_states, conditioning):\n        inp = hidden_states.clone()\n        temb_ = conditioning.clone()\n        modulated_inp, _, _, _, _ = dit.blocks[0].norm1_a(inp, emb=temb_)\n        if self.step == 0 or self.step == self.num_inference_steps - 1:\n            should_calc = True\n         
   self.accumulated_rel_l1_distance = 0\n        else: \n            coefficients = [4.98651651e+02, -2.83781631e+02,  5.58554382e+01, -3.82021401e+00, 2.64230861e-01]\n            rescale_func = np.poly1d(coefficients)\n            self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())\n            if self.accumulated_rel_l1_distance < self.rel_l1_thresh:\n                should_calc = False\n            else:\n                should_calc = True\n                self.accumulated_rel_l1_distance = 0\n        self.previous_modulated_input = modulated_inp \n        self.step += 1\n        if self.step == self.num_inference_steps:\n            self.step = 0\n        if should_calc:\n            self.previous_hidden_states = hidden_states.clone()\n        return not should_calc\n    \n    def store(self, hidden_states):\n        self.previous_residual = hidden_states - self.previous_hidden_states\n        self.previous_hidden_states = None\n\n    def update(self, hidden_states):\n        hidden_states = hidden_states + self.previous_residual\n        return hidden_states\n\n\nclass FastTileWorker:\n    def __init__(self):\n        pass\n\n\n    def build_mask(self, data, is_bound):\n        _, _, H, W = data.shape\n        h = repeat(torch.arange(H), \"H -> H W\", H=H, W=W)\n        w = repeat(torch.arange(W), \"W -> H W\", H=H, W=W)\n        border_width = (H + W) // 4\n        pad = torch.ones_like(h) * border_width\n        mask = torch.stack([\n            pad if is_bound[0] else h + 1,\n            pad if is_bound[1] else H - h,\n            pad if is_bound[2] else w + 1,\n            pad if is_bound[3] else W - w\n        ]).min(dim=0).values\n        mask = mask.clip(1, border_width)\n        mask = (mask / border_width).to(dtype=data.dtype, device=data.device)\n        mask = rearrange(mask, \"H W -> 1 H W\")\n        return mask\n\n\n    def 
tiled_forward(self, forward_fn, model_input, tile_size, tile_stride, tile_device=\"cpu\", tile_dtype=torch.float32, border_width=None):\n        # Prepare\n        B, C, H, W = model_input.shape\n        border_width = int(tile_stride*0.5) if border_width is None else border_width\n        weight = torch.zeros((1, 1, H, W), dtype=tile_dtype, device=tile_device)\n        values = torch.zeros((B, C, H, W), dtype=tile_dtype, device=tile_device)\n\n        # Split tasks\n        tasks = []\n        for h in range(0, H, tile_stride):\n            for w in range(0, W, tile_stride):\n                if (h-tile_stride >= 0 and h-tile_stride+tile_size >= H) or (w-tile_stride >= 0 and w-tile_stride+tile_size >= W):\n                    continue\n                h_, w_ = h + tile_size, w + tile_size\n                if h_ > H: h, h_ = H - tile_size, H\n                if w_ > W: w, w_ = W - tile_size, W\n                tasks.append((h, h_, w, w_))\n        \n        # Run\n        for hl, hr, wl, wr in tasks:\n            # Forward\n            hidden_states_batch = forward_fn(hl, hr, wl, wr).to(dtype=tile_dtype, device=tile_device)\n\n            mask = self.build_mask(hidden_states_batch, is_bound=(hl==0, hr>=H, wl==0, wr>=W))\n            values[:, :, hl:hr, wl:wr] += hidden_states_batch * mask\n            weight[:, :, hl:hr, wl:wr] += mask\n        values /= weight\n        return values\n\n    \ndef model_fn_flux_image(\n    dit: FluxDiT,\n    controlnet=None,\n    step1x_connector=None,\n    latents=None,\n    timestep=None,\n    prompt_emb=None,\n    pooled_prompt_emb=None,\n    guidance=None,\n    text_ids=None,\n    image_ids=None,\n    kontext_latents=None,\n    kontext_image_ids=None,\n    controlnet_inputs=None,\n    controlnet_conditionings=None,\n    tiled=False,\n    tile_size=128,\n    tile_stride=64,\n    entity_prompt_emb=None,\n    entity_masks=None,\n    ipadapter_kwargs_list={},\n    id_emb=None,\n    infinityou_guidance=None,\n    
flex_condition=None,\n    flex_uncondition=None,\n    flex_control_stop_timestep=None,\n    step1x_llm_embedding=None,\n    step1x_mask=None,\n    step1x_reference_latents=None,\n    tea_cache: TeaCache = None,\n    progress_id=0,\n    num_inference_steps=1,\n    use_gradient_checkpointing=False,\n    use_gradient_checkpointing_offload=False,\n    **kwargs\n):\n    if tiled:\n        def flux_forward_fn(hl, hr, wl, wr):\n            tiled_controlnet_conditionings = [f[:, :, hl: hr, wl: wr] for f in controlnet_conditionings] if controlnet_conditionings is not None else None\n            return model_fn_flux_image(\n                dit=dit,\n                controlnet=controlnet,\n                latents=latents[:, :, hl: hr, wl: wr],\n                timestep=timestep,\n                prompt_emb=prompt_emb,\n                pooled_prompt_emb=pooled_prompt_emb,\n                guidance=guidance,\n                text_ids=text_ids,\n                image_ids=None,\n                controlnet_inputs=controlnet_inputs,\n                controlnet_conditionings=tiled_controlnet_conditionings,\n                tiled=False,\n                **kwargs\n            )\n        return FastTileWorker().tiled_forward(\n            flux_forward_fn,\n            latents,\n            tile_size=tile_size,\n            tile_stride=tile_stride,\n            tile_device=latents.device,\n            tile_dtype=latents.dtype\n        )\n\n    hidden_states = latents\n\n    # ControlNet\n    if controlnet is not None and controlnet_conditionings is not None:\n        controlnet_extra_kwargs = {\n            \"hidden_states\": hidden_states,\n            \"timestep\": timestep,\n            \"prompt_emb\": prompt_emb,\n            \"pooled_prompt_emb\": pooled_prompt_emb,\n            \"guidance\": guidance,\n            \"text_ids\": text_ids,\n            \"image_ids\": image_ids,\n            \"controlnet_inputs\": controlnet_inputs,\n            \"tiled\": tiled,\n            
\"tile_size\": tile_size,\n            \"tile_stride\": tile_stride,\n            \"progress_id\": progress_id,\n            \"num_inference_steps\": num_inference_steps,\n        }\n        if id_emb is not None:\n            controlnet_text_ids = torch.zeros(id_emb.shape[0], id_emb.shape[1], 3).to(device=hidden_states.device, dtype=hidden_states.dtype)\n            controlnet_extra_kwargs.update({\"prompt_emb\": id_emb, 'text_ids': controlnet_text_ids, 'guidance': infinityou_guidance})\n        controlnet_res_stack, controlnet_single_res_stack = controlnet(\n            controlnet_conditionings, **controlnet_extra_kwargs\n        )\n        \n    # Flex\n    if flex_condition is not None:\n        if timestep.tolist()[0] >= flex_control_stop_timestep:\n            hidden_states = torch.concat([hidden_states, flex_condition], dim=1)\n        else:\n            hidden_states = torch.concat([hidden_states, flex_uncondition], dim=1)\n            \n    # Step1x\n    if step1x_llm_embedding is not None:\n        prompt_emb, pooled_prompt_emb = step1x_connector(step1x_llm_embedding, timestep / 1000, step1x_mask)\n        text_ids = torch.zeros((1, prompt_emb.shape[1], 3), dtype=prompt_emb.dtype, device=prompt_emb.device)\n\n    if image_ids is None:\n        image_ids = dit.prepare_image_ids(hidden_states)\n    \n    conditioning = dit.time_embedder(timestep, hidden_states.dtype) + dit.pooled_text_embedder(pooled_prompt_emb)\n    if dit.guidance_embedder is not None:\n        guidance = guidance * 1000\n        conditioning = conditioning + dit.guidance_embedder(guidance, hidden_states.dtype)\n\n    height, width = hidden_states.shape[-2:]\n    hidden_states = dit.patchify(hidden_states)\n    \n    # Kontext\n    if kontext_latents is not None:\n        image_ids = torch.concat([image_ids, kontext_image_ids], dim=-2)\n        hidden_states = torch.concat([hidden_states, kontext_latents], dim=1)\n    \n    # Step1x\n    if step1x_reference_latents is not None:\n        
step1x_reference_image_ids = dit.prepare_image_ids(step1x_reference_latents)\n        step1x_reference_latents = dit.patchify(step1x_reference_latents)\n        image_ids = torch.concat([image_ids, step1x_reference_image_ids], dim=-2)\n        hidden_states = torch.concat([hidden_states, step1x_reference_latents], dim=1)\n        \n    hidden_states = dit.x_embedder(hidden_states)\n\n    # EliGen\n    if entity_prompt_emb is not None and entity_masks is not None:\n        prompt_emb, image_rotary_emb, attention_mask = dit.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids, latents.shape[1])\n    else:\n        prompt_emb = dit.context_embedder(prompt_emb)\n        image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1))\n        attention_mask = None\n\n    # TeaCache\n    if tea_cache is not None:\n        tea_cache_update = tea_cache.check(dit, hidden_states, conditioning)\n    else:\n        tea_cache_update = False\n\n    if tea_cache_update:\n        hidden_states = tea_cache.update(hidden_states)\n    else:\n        # Joint Blocks\n        for block_id, block in enumerate(dit.blocks):\n            hidden_states, prompt_emb = gradient_checkpoint_forward(\n                block,\n                use_gradient_checkpointing,\n                use_gradient_checkpointing_offload,\n                hidden_states,\n                prompt_emb,\n                conditioning,\n                image_rotary_emb,\n                attention_mask,\n                ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None),\n            )\n            # ControlNet\n            if controlnet is not None and controlnet_conditionings is not None and controlnet_res_stack is not None:\n                if kontext_latents is None:\n                    hidden_states = hidden_states + controlnet_res_stack[block_id]\n                else:\n                    hidden_states[:, :-kontext_latents.shape[1]] = 
hidden_states[:, :-kontext_latents.shape[1]] + controlnet_res_stack[block_id]\n\n        # Single Blocks\n        hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)\n        num_joint_blocks = len(dit.blocks)\n        for block_id, block in enumerate(dit.single_blocks):\n            hidden_states, prompt_emb = gradient_checkpoint_forward(\n                block,\n                use_gradient_checkpointing,\n                use_gradient_checkpointing_offload,\n                hidden_states,\n                prompt_emb,\n                conditioning,\n                image_rotary_emb,\n                attention_mask,\n                ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None),\n            )\n            # ControlNet\n            if controlnet is not None and controlnet_conditionings is not None and controlnet_single_res_stack is not None:\n                if kontext_latents is None:\n                    hidden_states[:, prompt_emb.shape[1]:] = hidden_states[:, prompt_emb.shape[1]:] + controlnet_single_res_stack[block_id]\n                else:\n                    hidden_states[:, prompt_emb.shape[1]:-kontext_latents.shape[1]] = hidden_states[:, prompt_emb.shape[1]:-kontext_latents.shape[1]] + controlnet_single_res_stack[block_id]\n        hidden_states = hidden_states[:, prompt_emb.shape[1]:]\n\n        if tea_cache is not None:\n            tea_cache.store(hidden_states)\n\n    hidden_states = dit.final_norm_out(hidden_states, conditioning)\n    hidden_states = dit.final_proj_out(hidden_states)\n    \n    # Step1x\n    if step1x_reference_latents is not None:\n        hidden_states = hidden_states[:, :hidden_states.shape[1] // 2]\n    \n    # Kontext\n    if kontext_latents is not None:\n        hidden_states = hidden_states[:, :-kontext_latents.shape[1]]\n\n    hidden_states = dit.unpatchify(hidden_states, height, width)\n\n    return hidden_states\n"
  },
  {
    "path": "diffsynth/pipelines/ltx2_audio_video.py",
    "content": "import torch, types\nimport numpy as np\nfrom PIL import Image\nfrom einops import repeat\nfrom typing import Optional, Union\nfrom einops import rearrange\nimport numpy as np\nfrom PIL import Image\nfrom tqdm import tqdm\nfrom typing import Optional\nfrom transformers import AutoImageProcessor, Gemma3Processor\n\nfrom ..core.device.npu_compatible_device import get_device_type\nfrom ..diffusion import FlowMatchScheduler\nfrom ..core import ModelConfig\nfrom ..diffusion.base_pipeline import BasePipeline, PipelineUnit\n\nfrom ..models.ltx2_text_encoder import LTX2TextEncoder, LTX2TextEncoderPostModules, LTXVGemmaTokenizer\nfrom ..models.ltx2_dit import LTXModel\nfrom ..models.ltx2_video_vae import LTX2VideoEncoder, LTX2VideoDecoder, VideoLatentPatchifier\nfrom ..models.ltx2_audio_vae import LTX2AudioEncoder, LTX2AudioDecoder, LTX2Vocoder, AudioPatchifier, AudioProcessor\nfrom ..models.ltx2_upsampler import LTX2LatentUpsampler\nfrom ..models.ltx2_common import VideoLatentShape, AudioLatentShape, VideoPixelShape, get_pixel_coords, VIDEO_SCALE_FACTORS\nfrom ..utils.data.media_io_ltx2 import ltx2_preprocess\nfrom ..utils.data.audio import convert_to_stereo\n\n\nclass LTX2AudioVideoPipeline(BasePipeline):\n\n    def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):\n        super().__init__(\n            device=device,\n            torch_dtype=torch_dtype,\n            height_division_factor=32,\n            width_division_factor=32,\n            time_division_factor=8,\n            time_division_remainder=1,\n        )\n        self.scheduler = FlowMatchScheduler(\"LTX-2\")\n        self.text_encoder: LTX2TextEncoder = None\n        self.tokenizer: LTXVGemmaTokenizer = None\n        self.processor: Gemma3Processor = None\n        self.text_encoder_post_modules: LTX2TextEncoderPostModules = None\n        self.dit: LTXModel = None\n        self.video_vae_encoder: LTX2VideoEncoder = None\n        self.video_vae_decoder: LTX2VideoDecoder = 
None\n        self.audio_vae_encoder: LTX2AudioEncoder = None\n        self.audio_vae_decoder: LTX2AudioDecoder = None\n        self.audio_vocoder: LTX2Vocoder = None\n        self.upsampler: LTX2LatentUpsampler = None\n\n        self.video_patchifier: VideoLatentPatchifier = VideoLatentPatchifier(patch_size=1)\n        self.audio_patchifier: AudioPatchifier = AudioPatchifier(patch_size=1)\n        self.audio_processor: AudioProcessor = AudioProcessor()\n\n        self.in_iteration_models = (\"dit\",)\n        self.units = [\n            LTX2AudioVideoUnit_PipelineChecker(),\n            LTX2AudioVideoUnit_ShapeChecker(),\n            LTX2AudioVideoUnit_PromptEmbedder(),\n            LTX2AudioVideoUnit_NoiseInitializer(),\n            LTX2AudioVideoUnit_VideoRetakeEmbedder(),\n            LTX2AudioVideoUnit_AudioRetakeEmbedder(),\n            LTX2AudioVideoUnit_InputAudioEmbedder(),\n            LTX2AudioVideoUnit_InputVideoEmbedder(),\n            LTX2AudioVideoUnit_InputImagesEmbedder(),\n            LTX2AudioVideoUnit_InContextVideoEmbedder(),\n        ]\n        self.stage2_units = [\n            LTX2AudioVideoUnit_SwitchStage2(),\n            LTX2AudioVideoUnit_NoiseInitializer(),\n            LTX2AudioVideoUnit_LatentsUpsampler(),\n            LTX2AudioVideoUnit_VideoRetakeEmbedder(),\n            LTX2AudioVideoUnit_AudioRetakeEmbedder(),\n            LTX2AudioVideoUnit_InputImagesEmbedder(),\n            LTX2AudioVideoUnit_SetScheduleStage2(),\n        ]\n        self.model_fn = model_fn_ltx2\n\n        self.default_negative_prompt = {\n            \"LTX-2\": (\n                \"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, \"\n                \"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, \"\n                \"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, \"\n                \"wrong hand count, 
artifacts around text, inconsistent perspective, camera shake, incorrect depth of \"\n                \"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent \"\n                \"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny \"\n                \"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, \"\n                \"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, \"\n                \"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward \"\n                \"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, \"\n                \"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts.\"\n            ),\n            \"LTX-2.3\": (\n                \"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, \"\n                \"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, \"\n                \"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, \"\n                \"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of \"\n                \"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent \"\n                \"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny \"\n                \"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, \"\n                \"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, \"\n                \"off-sync audio, incorrect dialogue, added 
dialogue, repetitive speech, jittery movement, awkward \"\n                \"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, \"\n                \"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts.\"\n            ),\n        }\n\n    @staticmethod\n    def from_pretrained(\n        torch_dtype: torch.dtype = torch.bfloat16,\n        device: Union[str, torch.device] = get_device_type(),\n        model_configs: list[ModelConfig] = [],\n        tokenizer_config: ModelConfig = ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\"),\n        stage2_lora_config: Optional[ModelConfig] = None,\n        stage2_lora_strength: float = 0.8,\n        vram_limit: float = None,\n    ):\n        # Initialize pipeline\n        pipe = LTX2AudioVideoPipeline(device=device, torch_dtype=torch_dtype)\n        model_pool = pipe.download_and_load_models(model_configs, vram_limit)\n\n        # Fetch models\n        pipe.text_encoder = model_pool.fetch_model(\"ltx2_text_encoder\")\n        tokenizer_config.download_if_necessary()\n        pipe.tokenizer = LTXVGemmaTokenizer(tokenizer_path=tokenizer_config.path)\n        image_processor = AutoImageProcessor.from_pretrained(tokenizer_config.path, local_files_only=True)\n        pipe.processor = Gemma3Processor(image_processor=image_processor, tokenizer=pipe.tokenizer.tokenizer)\n\n        pipe.text_encoder_post_modules = model_pool.fetch_model(\"ltx2_text_encoder_post_modules\")\n        pipe.dit = model_pool.fetch_model(\"ltx2_dit\")\n        pipe.video_vae_encoder = model_pool.fetch_model(\"ltx2_video_vae_encoder\")\n        pipe.video_vae_decoder = model_pool.fetch_model(\"ltx2_video_vae_decoder\")\n        pipe.audio_vae_decoder = model_pool.fetch_model(\"ltx2_audio_vae_decoder\")\n        pipe.audio_vocoder = model_pool.fetch_model(\"ltx2_audio_vocoder\")\n        pipe.upsampler = model_pool.fetch_model(\"ltx2_latent_upsampler\")\n        
pipe.audio_vae_encoder = model_pool.fetch_model(\"ltx2_audio_vae_encoder\")\n\n        # Stage 2\n        if stage2_lora_config is not None:\n            pipe.stage2_lora_config = stage2_lora_config\n            pipe.stage2_lora_strength = stage2_lora_strength\n\n        # VRAM Management\n        pipe.vram_management_enabled = pipe.check_vram_management_state()\n        return pipe\n\n    def denoise_stage(self, inputs_shared, inputs_posi, inputs_nega, units, cfg_scale=1.0, progress_bar_cmd=tqdm, skip_stage=False):\n        if skip_stage:\n            return inputs_shared, inputs_posi, inputs_nega\n        for unit in units:\n            inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)\n        self.load_models_to_device(self.in_iteration_models)\n        models = {name: getattr(self, name) for name in self.in_iteration_models}\n        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):\n            timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)\n            noise_pred_video, noise_pred_audio = self.cfg_guided_model_fn(\n                self.model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega,\n                **models, timestep=timestep, progress_id=progress_id\n            )\n            inputs_shared[\"video_latents\"] = self.step(self.scheduler, inputs_shared[\"video_latents\"], progress_id=progress_id, noise_pred=noise_pred_video,\n                                                       inpaint_mask=inputs_shared.get(\"denoise_mask_video\", None), input_latents=inputs_shared.get(\"input_latents_video\", None), **inputs_shared)\n            inputs_shared[\"audio_latents\"] = self.step(self.scheduler, inputs_shared[\"audio_latents\"], progress_id=progress_id, noise_pred=noise_pred_audio,\n                                                       inpaint_mask=inputs_shared.get(\"denoise_mask_audio\", None), 
input_latents=inputs_shared.get(\"input_latents_audio\", None), **inputs_shared)\n        return inputs_shared, inputs_posi, inputs_nega\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        # Prompt\n        prompt: str,\n        negative_prompt: Optional[str] = \"\",\n        denoising_strength: float = 1.0,\n        # Image-to-video\n        input_images: Optional[list[Image.Image]] = None,\n        input_images_indexes: Optional[list[int]] = [0],\n        input_images_strength: Optional[float] = 1.0,\n        # In-Context Video Control\n        in_context_videos: Optional[list[list[Image.Image]]] = None,\n        in_context_downsample_factor: Optional[int] = 2,\n        # Video-to-video\n        retake_video: Optional[list[Image.Image]] = None,\n        retake_video_regions: Optional[list[tuple[float, float]]] = None,\n        # Audio-to-video\n        retake_audio: Optional[torch.Tensor] = None,\n        audio_sample_rate: Optional[int] = 48000,\n        retake_audio_regions: Optional[list[tuple[float, float]]] = None,\n        # Randomness\n        seed: Optional[int] = None,\n        rand_device: Optional[str] = \"cpu\",\n        # Shape\n        height: Optional[int] = 512,\n        width: Optional[int] = 768,\n        num_frames: Optional[int] = 121,\n        frame_rate: Optional[int] = 24,\n        # Classifier-free guidance\n        cfg_scale: Optional[float] = 3.0,\n        # Scheduler\n        num_inference_steps: Optional[int] = 30,\n        # VAE tiling\n        tiled: Optional[bool] = True,\n        tile_size_in_pixels: Optional[int] = 512,\n        tile_overlap_in_pixels: Optional[int] = 128,\n        tile_size_in_frames: Optional[int] = 128,\n        tile_overlap_in_frames: Optional[int] = 24,\n        # Special Pipelines\n        use_two_stage_pipeline: Optional[bool] = False,\n        stage2_spatial_upsample_factor: Optional[int] = 2,\n        clear_lora_before_state_two: Optional[bool] = False,\n        use_distilled_pipeline: 
Optional[bool] = False,\n        # progress_bar\n        progress_bar_cmd=tqdm,\n    ):\n        # Scheduler\n        self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, special_case=\"ditilled_stage1\" if use_distilled_pipeline else None)\n        # Inputs\n        inputs_posi = {\n            \"prompt\": prompt,\n        }\n        inputs_nega = {\n            \"negative_prompt\": negative_prompt,\n        }\n        inputs_shared = {\n            \"input_images\": input_images, \"input_images_indexes\": input_images_indexes, \"input_images_strength\": input_images_strength,\n            \"retake_video\": retake_video, \"retake_video_regions\": retake_video_regions,\n            \"retake_audio\": (retake_audio, audio_sample_rate) if retake_audio is not None else None, \"retake_audio_regions\": retake_audio_regions,\n            \"in_context_videos\": in_context_videos, \"in_context_downsample_factor\": in_context_downsample_factor,\n            \"seed\": seed, \"rand_device\": rand_device,\n            \"height\": height, \"width\": width, \"num_frames\": num_frames, \"frame_rate\": frame_rate,\n            \"cfg_scale\": cfg_scale,\n            \"tiled\": tiled, \"tile_size_in_pixels\": tile_size_in_pixels, \"tile_overlap_in_pixels\": tile_overlap_in_pixels,\n            \"tile_size_in_frames\": tile_size_in_frames, \"tile_overlap_in_frames\": tile_overlap_in_frames,\n            \"use_two_stage_pipeline\": use_two_stage_pipeline, \"use_distilled_pipeline\": use_distilled_pipeline, \"clear_lora_before_state_two\": clear_lora_before_state_two, \"stage2_spatial_upsample_factor\": stage2_spatial_upsample_factor,\n            \"video_patchifier\": self.video_patchifier, \"audio_patchifier\": self.audio_patchifier,\n        }\n        # Stage 1\n        inputs_shared, inputs_posi, inputs_nega = self.denoise_stage(inputs_shared, inputs_posi, inputs_nega, self.units, cfg_scale, progress_bar_cmd)\n        # Stage 2\n        
inputs_shared, inputs_posi, inputs_nega = self.denoise_stage(inputs_shared, inputs_posi, inputs_nega, self.stage2_units, 1.0, progress_bar_cmd, not inputs_shared[\"use_two_stage_pipeline\"])\n        # Decode\n        self.load_models_to_device(['video_vae_decoder'])\n        video = self.video_vae_decoder.decode(inputs_shared[\"video_latents\"], tiled, tile_size_in_pixels, tile_overlap_in_pixels, tile_size_in_frames, tile_overlap_in_frames)\n        video = self.vae_output_to_video(video)\n        self.load_models_to_device(['audio_vae_decoder', 'audio_vocoder'])\n        decoded_audio = self.audio_vae_decoder(inputs_shared[\"audio_latents\"])\n        decoded_audio = self.audio_vocoder(decoded_audio)\n        decoded_audio = self.output_audio_format_check(decoded_audio)\n        return video, decoded_audio\n\n\nclass LTX2AudioVideoUnit_PipelineChecker(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            take_over=True,\n            input_params=(\"use_distilled_pipeline\", \"use_two_stage_pipeline\"),\n            output_params=(\"use_two_stage_pipeline\", \"cfg_scale\")\n        )\n\n    def process(self, pipe: LTX2AudioVideoPipeline, inputs_shared, inputs_posi, inputs_nega):\n        if inputs_shared.get(\"use_distilled_pipeline\", False):\n            inputs_shared[\"use_two_stage_pipeline\"] = True\n            inputs_shared[\"cfg_scale\"] = 1.0\n            print(f\"Distilled pipeline requested, setting use_two_stage_pipeline to True, disable CFG by setting cfg_scale to 1.0.\")\n        if inputs_shared.get(\"use_two_stage_pipeline\", False):\n            # distill pipeline also uses two-stage, but it does not needs lora\n            if not inputs_shared.get(\"use_distilled_pipeline\", False):\n                if not (hasattr(pipe, \"stage2_lora_config\") and pipe.stage2_lora_config is not None):\n                    raise ValueError(\"Two-stage pipeline requested, but stage2_lora_config is not set in the pipeline.\")\n            
if not (hasattr(pipe, \"upsampler\") and pipe.upsampler is not None):\n                raise ValueError(\"Two-stage pipeline requested, but upsampler model is not loaded in the pipeline.\")\n        return inputs_shared, inputs_posi, inputs_nega\n\n\nclass LTX2AudioVideoUnit_ShapeChecker(PipelineUnit):\n    \"\"\"\n    For two-stage pipelines, the resolution must be divisible by 64.\n    For one-stage pipelines, the resolution must be divisible by 32.\n    This unit set height and width to stage 1 resolution, and stage_2_width and stage_2_height.\n    \"\"\"\n    def __init__(self):\n        super().__init__(\n            input_params=(\"height\", \"width\", \"num_frames\", \"use_two_stage_pipeline\", \"stage2_spatial_upsample_factor\"),\n            output_params=(\"height\", \"width\", \"num_frames\", \"stage_2_height\", \"stage_2_width\"),\n        )\n\n    def process(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, use_two_stage_pipeline=False, stage2_spatial_upsample_factor=2):\n        if use_two_stage_pipeline:\n            height, width = height // stage2_spatial_upsample_factor, width // stage2_spatial_upsample_factor\n            height, width, num_frames = pipe.check_resize_height_width(height, width, num_frames)\n            stage_2_height, stage_2_width = int(height * stage2_spatial_upsample_factor), int(width * stage2_spatial_upsample_factor)\n        else:\n            stage_2_height, stage_2_width = None, None\n            height, width, num_frames = pipe.check_resize_height_width(height, width, num_frames)\n        return {\"height\": height, \"width\": width, \"num_frames\": num_frames, \"stage_2_height\": stage_2_height, \"stage_2_width\": stage_2_width}\n\n\nclass LTX2AudioVideoUnit_PromptEmbedder(PipelineUnit):\n\n    def __init__(self):\n        super().__init__(\n            seperate_cfg=True,\n            input_params_posi={\"prompt\": \"prompt\"},\n            input_params_nega={\"prompt\": \"negative_prompt\"},\n            
output_params=(\"video_context\", \"audio_context\"),\n            onload_model_names=(\"text_encoder\", \"text_encoder_post_modules\"),\n        )\n    def _preprocess_text(\n        self,\n        pipe,\n        text: str,\n    ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:\n        token_pairs = pipe.tokenizer.tokenize_with_weights(text)[\"gemma\"]\n        input_ids = torch.tensor([[t[0] for t in token_pairs]], device=pipe.device)\n        attention_mask = torch.tensor([[w[1] for w in token_pairs]], device=pipe.device)\n        outputs = pipe.text_encoder(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)\n        return outputs.hidden_states, attention_mask\n    def encode_prompt(self, pipe, text, padding_side=\"left\"):\n        hidden_states, attention_mask = self._preprocess_text(pipe, text)\n        video_encoding, audio_encoding, attention_mask = pipe.text_encoder_post_modules.process_hidden_states(\n            hidden_states, attention_mask, padding_side)\n        return video_encoding, audio_encoding, attention_mask\n\n    def process(self, pipe: LTX2AudioVideoPipeline, prompt: str):\n        pipe.load_models_to_device(self.onload_model_names)\n        video_context, audio_context, _ = self.encode_prompt(pipe, prompt)\n        return {\"video_context\": video_context, \"audio_context\": audio_context}\n\n\nclass LTX2AudioVideoUnit_NoiseInitializer(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            input_params=(\"height\", \"width\", \"num_frames\", \"seed\", \"rand_device\", \"frame_rate\"),\n            output_params=(\"video_noise\", \"audio_noise\", \"video_positions\", \"audio_positions\", \"video_latent_shape\", \"audio_latent_shape\")\n        )\n\n    def process_stage(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, seed, rand_device, frame_rate=24.0):\n        video_pixel_shape = VideoPixelShape(batch=1, frames=num_frames, width=width, height=height, fps=frame_rate)\n   
     video_latent_shape = VideoLatentShape.from_pixel_shape(shape=video_pixel_shape, latent_channels=128)\n        video_noise = pipe.generate_noise(video_latent_shape.to_torch_shape(), seed=seed, rand_device=rand_device)\n\n        latent_coords = pipe.video_patchifier.get_patch_grid_bounds(output_shape=video_latent_shape, device=pipe.device)\n        video_positions = get_pixel_coords(latent_coords, VIDEO_SCALE_FACTORS, True).float()\n        video_positions[:, 0, ...] = video_positions[:, 0, ...] / frame_rate\n        video_positions = video_positions.to(pipe.torch_dtype)\n\n        audio_latent_shape = AudioLatentShape.from_video_pixel_shape(video_pixel_shape)\n        audio_noise = pipe.generate_noise(audio_latent_shape.to_torch_shape(), seed=seed, rand_device=rand_device)\n        audio_positions = pipe.audio_patchifier.get_patch_grid_bounds(audio_latent_shape, device=pipe.device)\n        return {\n            \"video_noise\": video_noise,\n            \"audio_noise\": audio_noise,\n            \"video_positions\": video_positions,\n            \"audio_positions\": audio_positions,\n            \"video_latent_shape\": video_latent_shape,\n            \"audio_latent_shape\": audio_latent_shape\n        }\n\n    def process(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, seed, rand_device, frame_rate=24.0):\n        return self.process_stage(pipe, height, width, num_frames, seed, rand_device, frame_rate)\n\n\nclass LTX2AudioVideoUnit_InputVideoEmbedder(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            input_params=(\"input_video\", \"video_noise\", \"tiled\", \"tile_size_in_pixels\", \"tile_overlap_in_pixels\"),\n            output_params=(\"video_latents\", \"input_latents\"),\n            onload_model_names=(\"video_vae_encoder\")\n        )\n\n    def process(self, pipe: LTX2AudioVideoPipeline, input_video, video_noise, tiled, tile_size_in_pixels, tile_overlap_in_pixels):\n        if input_video is None or not 
pipe.scheduler.training:\n            return {\"video_latents\": video_noise}\n        else:\n            pipe.load_models_to_device(self.onload_model_names)\n            input_video = pipe.preprocess_video(input_video)\n            input_latents = pipe.video_vae_encoder.encode(input_video, tiled, tile_size_in_pixels, tile_overlap_in_pixels).to(dtype=pipe.torch_dtype, device=pipe.device)\n            return {\"video_latents\": input_latents, \"input_latents\": input_latents}\n\nclass LTX2AudioVideoUnit_InputAudioEmbedder(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            input_params=(\"input_audio\", \"audio_noise\"),\n            output_params=(\"audio_latents\", \"audio_input_latents\", \"audio_positions\", \"audio_latent_shape\"),\n            onload_model_names=(\"audio_vae_encoder\",)\n        )\n\n    def process(self, pipe: LTX2AudioVideoPipeline, input_audio, audio_noise):\n        if input_audio is None or not pipe.scheduler.training:\n            return {\"audio_latents\": audio_noise}\n        else:\n            input_audio, sample_rate = input_audio\n            input_audio = convert_to_stereo(input_audio)\n            pipe.load_models_to_device(self.onload_model_names)\n            input_audio = pipe.audio_processor.waveform_to_mel(input_audio.unsqueeze(0), waveform_sample_rate=sample_rate).to(dtype=pipe.torch_dtype)\n            audio_input_latents = pipe.audio_vae_encoder(input_audio)\n            audio_latent_shape = AudioLatentShape.from_torch_shape(audio_input_latents.shape)\n            audio_positions = pipe.audio_patchifier.get_patch_grid_bounds(audio_latent_shape, device=pipe.device)\n            return {\"audio_latents\": audio_input_latents, \"audio_input_latents\": audio_input_latents, \"audio_positions\": audio_positions, \"audio_latent_shape\": audio_latent_shape}\n\n\nclass LTX2AudioVideoUnit_VideoRetakeEmbedder(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            
input_params=(\"retake_video\", \"height\", \"width\", \"tiled\", \"tile_size_in_pixels\", \"tile_overlap_in_pixels\", \"video_positions\", \"retake_video_regions\"),\n            output_params=(\"input_latents_video\", \"denoise_mask_video\"),\n            onload_model_names=(\"video_vae_encoder\")\n        )\n\n    def process(self, pipe: LTX2AudioVideoPipeline, retake_video, height, width, tiled, tile_size_in_pixels, tile_overlap_in_pixels, video_positions, retake_video_regions=None):\n        if retake_video is None:\n            return {}\n        pipe.load_models_to_device(self.onload_model_names)\n        resized_video = [frame.resize((width, height)) for frame in retake_video]\n        input_video = pipe.preprocess_video(resized_video)\n        input_latents_video = pipe.video_vae_encoder.encode(input_video, tiled, tile_size_in_pixels, tile_overlap_in_pixels).to(dtype=pipe.torch_dtype, device=pipe.device)\n\n        b, c, f, h, w = input_latents_video.shape\n        denoise_mask_video = torch.zeros((b, 1, f, h, w), device=input_latents_video.device, dtype=input_latents_video.dtype)\n        if retake_video_regions is not None and len(retake_video_regions) > 0:\n            for start_time, end_time in retake_video_regions:\n                t_start, t_end = video_positions[0, 0].unbind(dim=-1)\n                in_region = (t_end >= start_time) & (t_start <= end_time)\n                in_region = pipe.video_patchifier.unpatchify_video(in_region.unsqueeze(0).unsqueeze(-1), f, h, w)\n                denoise_mask_video = torch.where(in_region, torch.ones_like(denoise_mask_video), denoise_mask_video)\n\n        return {\"input_latents_video\": input_latents_video, \"denoise_mask_video\": denoise_mask_video}\n\n\nclass LTX2AudioVideoUnit_AudioRetakeEmbedder(PipelineUnit):\n    \"\"\"\n    Functionality of audio2video, audio retaking.\n    \"\"\"\n    def __init__(self):\n        super().__init__(\n            input_params=(\"retake_audio\", \"seed\", 
\"rand_device\", \"retake_audio_regions\"),\n            output_params=(\"input_latents_audio\", \"audio_noise\", \"audio_positions\", \"audio_latent_shape\", \"denoise_mask_audio\"),\n            onload_model_names=(\"audio_vae_encoder\",)\n        )\n\n    def process(self, pipe: LTX2AudioVideoPipeline, retake_audio, seed, rand_device, retake_audio_regions=None):\n        if retake_audio is None:\n            return {}\n        else:\n            input_audio, sample_rate = retake_audio\n            input_audio = convert_to_stereo(input_audio)\n            pipe.load_models_to_device(self.onload_model_names)\n            input_audio = pipe.audio_processor.waveform_to_mel(input_audio.unsqueeze(0), waveform_sample_rate=sample_rate).to(dtype=pipe.torch_dtype, device=pipe.device)\n            input_latents_audio = pipe.audio_vae_encoder(input_audio)\n            audio_latent_shape = AudioLatentShape.from_torch_shape(input_latents_audio.shape)\n            audio_positions = pipe.audio_patchifier.get_patch_grid_bounds(audio_latent_shape, device=pipe.device)\n            # Regenerate noise for the new shape if retake_audio is provided, to avoid shape mismatch.\n            audio_noise = pipe.generate_noise(input_latents_audio.shape, seed=seed, rand_device=rand_device)\n\n            b, c, t, f = input_latents_audio.shape\n            denoise_mask_audio = torch.zeros((b, 1, t, 1), device=input_latents_audio.device, dtype=input_latents_audio.dtype)\n            if retake_audio_regions is not None and len(retake_audio_regions) > 0:\n                for start_time, end_time in retake_audio_regions:\n                    t_start, t_end = audio_positions[:, 0, :, 0], audio_positions[:, 0, :, 1]\n                    in_region = (t_end >= start_time) & (t_start <= end_time)\n                    in_region = pipe.audio_patchifier.unpatchify_audio(in_region.unsqueeze(-1), 1, 1)\n                    denoise_mask_audio = torch.where(in_region, torch.ones_like(denoise_mask_audio), 
denoise_mask_audio)\n\n            return {\n                \"input_latents_audio\": input_latents_audio,\n                \"denoise_mask_audio\": denoise_mask_audio,\n                \"audio_noise\": audio_noise,\n                \"audio_positions\": audio_positions,\n                \"audio_latent_shape\": audio_latent_shape,\n            }\n\n\nclass LTX2AudioVideoUnit_InputImagesEmbedder(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            input_params=(\"input_images\", \"input_images_indexes\", \"input_images_strength\", \"video_latents\", \"height\", \"width\", \"frame_rate\", \"tiled\", \"tile_size_in_pixels\", \"tile_overlap_in_pixels\", \"input_latents_video\", \"denoise_mask_video\"),\n            output_params=(\"denoise_mask_video\", \"input_latents_video\", \"ref_frames_latents\", \"ref_frames_positions\"),\n            onload_model_names=(\"video_vae_encoder\")\n        )\n\n    def get_image_latent(self, pipe, input_image, height, width, tiled, tile_size_in_pixels, tile_overlap_in_pixels):\n        image = ltx2_preprocess(np.array(input_image.resize((width, height))))\n        image = torch.Tensor(np.array(image, dtype=np.float32)).to(dtype=pipe.torch_dtype, device=pipe.device)\n        image = image / 127.5 - 1.0\n        image = repeat(image, f\"H W C -> B C F H W\", B=1, F=1)\n        latents = pipe.video_vae_encoder.encode(image, tiled, tile_size_in_pixels, tile_overlap_in_pixels).to(pipe.device)\n        return latents\n\n    def apply_input_images_to_latents(self, latents, input_latents, input_indexes, input_strength=1.0, input_latents_video=None, denoise_mask_video=None):\n        b, _, f, h, w = latents.shape\n        denoise_mask = torch.ones((b, 1, f, h, w), dtype=latents.dtype, device=latents.device) if denoise_mask_video is None else denoise_mask_video\n        input_latents_video = torch.zeros_like(latents) if input_latents_video is None else input_latents_video\n        for idx, input_latent in 
zip(input_indexes, input_latents):\n            idx = min(max(1 + (idx-1) // 8, 0), f - 1)\n            input_latent = input_latent.to(dtype=latents.dtype, device=latents.device)\n            input_latents_video[:, :, idx:idx + input_latent.shape[2], :, :] = input_latent\n            denoise_mask[:, :, idx:idx + input_latent.shape[2], :, :] = 1.0 - input_strength\n        return input_latents_video, denoise_mask\n\n    def process(\n        self,\n        pipe: LTX2AudioVideoPipeline,\n        video_latents,\n        input_images,\n        height,\n        width,\n        frame_rate,\n        tiled,\n        tile_size_in_pixels,\n        tile_overlap_in_pixels,\n        input_images_indexes=[0],\n        input_images_strength=1.0,\n        input_latents_video=None,\n        denoise_mask_video=None,\n    ):\n        if input_images is None or len(input_images) == 0:\n            return {}\n        else:\n            if len(input_images_indexes) != len(set(input_images_indexes)):\n                raise ValueError(\"Input images must have unique indexes.\")\n            pipe.load_models_to_device(self.onload_model_names)\n            frame_conditions = {\"input_latents_video\": None, \"denoise_mask_video\": None, \"ref_frames_latents\": [], \"ref_frames_positions\": []}\n            for img, index in zip(input_images, input_images_indexes):\n                latents = self.get_image_latent(pipe, img, height, width, tiled, tile_size_in_pixels, tile_overlap_in_pixels)\n                # first_frame by replacing latents\n                if index == 0:\n                    input_latents_video, denoise_mask_video = self.apply_input_images_to_latents(\n                        video_latents, [latents], [0], input_images_strength, input_latents_video, denoise_mask_video)\n                    frame_conditions.update({\"input_latents_video\": input_latents_video, \"denoise_mask_video\": denoise_mask_video})\n                # other frames by adding reference latents\n            
    else:\n                    latent_coords = pipe.video_patchifier.get_patch_grid_bounds(output_shape=VideoLatentShape.from_torch_shape(latents.shape), device=pipe.device)\n                    video_positions = get_pixel_coords(latent_coords, VIDEO_SCALE_FACTORS, False).float()\n                    video_positions[:, 0, ...] = (video_positions[:, 0, ...] + index) / frame_rate\n                    video_positions = video_positions.to(pipe.torch_dtype)\n                    frame_conditions[\"ref_frames_latents\"].append(latents)\n                    frame_conditions[\"ref_frames_positions\"].append(video_positions)\n            if len(frame_conditions[\"ref_frames_latents\"]) == 0:\n                frame_conditions.update({\"ref_frames_latents\": None, \"ref_frames_positions\": None})\n            return frame_conditions\n\n\nclass LTX2AudioVideoUnit_InContextVideoEmbedder(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            input_params=(\"in_context_videos\", \"height\", \"width\", \"num_frames\", \"frame_rate\", \"in_context_downsample_factor\", \"tiled\", \"tile_size_in_pixels\", \"tile_overlap_in_pixels\"),\n            output_params=(\"in_context_video_latents\", \"in_context_video_positions\"),\n            onload_model_names=(\"video_vae_encoder\")\n        )\n\n    def check_in_context_video(self, pipe, in_context_video, height, width, num_frames, in_context_downsample_factor):\n        if in_context_video is None or len(in_context_video) == 0:\n            raise ValueError(\"In-context video is None or empty.\")\n        in_context_video = in_context_video[:num_frames]\n        expected_height = height // in_context_downsample_factor\n        expected_width = width // in_context_downsample_factor\n        current_h, current_w, current_f = in_context_video[0].size[1], in_context_video[0].size[0], len(in_context_video)\n        h, w, f = pipe.check_resize_height_width(expected_height, expected_width, current_f, verbose=0)\n        
if current_h != h or current_w != w:\n            in_context_video = [img.resize((w, h)) for img in in_context_video]\n        if current_f != f:\n            # pad black frames at the end\n            in_context_video = in_context_video + [Image.new(\"RGB\", (w, h), (0, 0, 0))] * (f - current_f)\n        return in_context_video\n\n    def process(self, pipe: LTX2AudioVideoPipeline, in_context_videos, height, width, num_frames, frame_rate, in_context_downsample_factor, tiled, tile_size_in_pixels, tile_overlap_in_pixels):\n        if in_context_videos is None or len(in_context_videos) == 0:\n            return {}\n        else:\n            pipe.load_models_to_device(self.onload_model_names)\n            latents, positions = [], []\n            for in_context_video in in_context_videos:\n                in_context_video = self.check_in_context_video(pipe, in_context_video, height, width, num_frames, in_context_downsample_factor)\n                in_context_video = pipe.preprocess_video(in_context_video)\n                in_context_latents = pipe.video_vae_encoder.encode(in_context_video, tiled, tile_size_in_pixels, tile_overlap_in_pixels).to(dtype=pipe.torch_dtype, device=pipe.device)\n\n                latent_coords = pipe.video_patchifier.get_patch_grid_bounds(output_shape=VideoLatentShape.from_torch_shape(in_context_latents.shape), device=pipe.device)\n                video_positions = get_pixel_coords(latent_coords, VIDEO_SCALE_FACTORS, True).float()\n                video_positions[:, 0, ...] = video_positions[:, 0, ...] / frame_rate\n                video_positions[:, 1, ...] *= in_context_downsample_factor  # height axis\n                video_positions[:, 2, ...] 
*= in_context_downsample_factor  # width axis\n                video_positions = video_positions.to(pipe.torch_dtype)\n\n                latents.append(in_context_latents)\n                positions.append(video_positions)\n            latents = torch.cat(latents, dim=1)\n            positions = torch.cat(positions, dim=1)\n            return {\"in_context_video_latents\": latents, \"in_context_video_positions\": positions}\n\n\nclass LTX2AudioVideoUnit_SwitchStage2(PipelineUnit):\n    \"\"\"\n    1. switch height and width to stage 2 resolution\n    2. clear in_context_video_latents and in_context_video_positions\n    3. switch stage 2 lora model\n    \"\"\"\n    def __init__(self):\n        super().__init__(\n            input_params=(\"stage_2_height\", \"stage_2_width\", \"clear_lora_before_state_two\", \"use_distilled_pipeline\"),\n            output_params=(\"height\", \"width\", \"in_context_video_latents\", \"in_context_video_positions\"),\n        )\n\n    def process(self, pipe: LTX2AudioVideoPipeline, stage_2_height, stage_2_width, clear_lora_before_state_two, use_distilled_pipeline):\n        stage2_params = {}\n        stage2_params.update({\"height\": stage_2_height, \"width\": stage_2_width})\n        stage2_params.update({\"in_context_video_latents\": None, \"in_context_video_positions\": None})\n        stage2_params.update({\"input_latents_video\": None, \"denoise_mask_video\": None})\n        if clear_lora_before_state_two:\n            pipe.clear_lora()\n        if not use_distilled_pipeline:\n            pipe.load_lora(pipe.dit, pipe.stage2_lora_config, alpha=pipe.stage2_lora_strength, state_dict=pipe.stage2_lora_config.state_dict)\n        return stage2_params\n\n\nclass LTX2AudioVideoUnit_SetScheduleStage2(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            input_params=(\"video_latents\", \"video_noise\", \"audio_latents\", \"audio_noise\"),\n            output_params=(\"video_latents\", \"audio_latents\"),\n      
  )\n\n    def process(self, pipe: LTX2AudioVideoPipeline, video_latents, video_noise, audio_latents, audio_noise):\n        pipe.scheduler.set_timesteps(special_case=\"stage2\")\n        video_latents = pipe.scheduler.add_noise(video_latents, video_noise, pipe.scheduler.timesteps[0])\n        audio_latents = pipe.scheduler.add_noise(audio_latents, audio_noise, pipe.scheduler.timesteps[0])\n        return {\"video_latents\": video_latents, \"audio_latents\": audio_latents}\n\n\nclass LTX2AudioVideoUnit_LatentsUpsampler(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            input_params=(\"video_latents\",),\n            output_params=(\"video_latents\",),\n            onload_model_names=(\"upsampler\",),\n        )\n\n    def process(self, pipe: LTX2AudioVideoPipeline, video_latents):\n        if video_latents is None or pipe.upsampler is None:\n            raise ValueError(\"No upsampler or no video latents before stage 2.\")\n        else:\n            pipe.load_models_to_device(self.onload_model_names)\n            video_latents = pipe.video_vae_encoder.per_channel_statistics.un_normalize(video_latents)\n            video_latents = pipe.upsampler(video_latents)\n            video_latents = pipe.video_vae_encoder.per_channel_statistics.normalize(video_latents)\n            return {\"video_latents\": video_latents}\n\n\ndef model_fn_ltx2(\n    dit: LTXModel,\n    video_latents=None,\n    video_context=None,\n    video_positions=None,\n    video_patchifier=None,\n    audio_latents=None,\n    audio_context=None,\n    audio_positions=None,\n    audio_patchifier=None,\n    timestep=None,\n    # First Frame Conditioning\n    input_latents_video=None,\n    denoise_mask_video=None,\n    # Other Frames Conditioning\n    ref_frames_latents=None,\n    ref_frames_positions=None,\n    # In-Context Conditioning\n    in_context_video_latents=None,\n    in_context_video_positions=None,\n    # Audio Inputs\n    input_latents_audio=None,\n    
denoise_mask_audio=None,\n    # Gradient Checkpointing\n    use_gradient_checkpointing=False,\n    use_gradient_checkpointing_offload=False,\n    **kwargs,\n):\n    timestep = timestep.float() / 1000.\n\n    # patchify\n    b, c_v, f, h, w = video_latents.shape\n    video_latents = video_patchifier.patchify(video_latents)\n    seq_len_video = video_latents.shape[1]\n    video_timesteps = timestep.repeat(1, video_latents.shape[1], 1)\n    # Frist frame conditioning by replacing the video latents\n    if input_latents_video is not None:\n        denoise_mask_video = video_patchifier.patchify(denoise_mask_video)\n        video_latents = video_latents * denoise_mask_video + video_patchifier.patchify(input_latents_video) * (1.0 - denoise_mask_video)\n        video_timesteps = denoise_mask_video * video_timesteps\n\n    # Reference conditioning by appending the reference video or frame latents\n    total_ref_latents = ref_frames_latents if ref_frames_latents is not None else []\n    total_ref_positions = ref_frames_positions if ref_frames_positions is not None else []\n    total_ref_latents += [in_context_video_latents] if in_context_video_latents is not None else []\n    total_ref_positions += [in_context_video_positions] if in_context_video_positions is not None else []\n    if len(total_ref_latents) > 0:\n        for ref_frames_latent, ref_frames_position in zip(total_ref_latents, total_ref_positions):\n            ref_frames_latent = video_patchifier.patchify(ref_frames_latent)\n            ref_frames_timestep = timestep.repeat(1, ref_frames_latent.shape[1], 1) * 0.\n            video_latents = torch.cat([video_latents, ref_frames_latent], dim=1)\n            video_positions = torch.cat([video_positions, ref_frames_position], dim=2)\n            video_timesteps = torch.cat([video_timesteps, ref_frames_timestep], dim=1)\n\n    if audio_latents is not None:\n        _, c_a, _, mel_bins  = audio_latents.shape\n        audio_latents = 
audio_patchifier.patchify(audio_latents)\n        audio_timesteps = timestep.repeat(1, audio_latents.shape[1], 1)\n    else:\n        audio_timesteps = None\n    if input_latents_audio is not None:\n        denoise_mask_audio = audio_patchifier.patchify(denoise_mask_audio)\n        audio_latents = audio_latents * denoise_mask_audio + audio_patchifier.patchify(input_latents_audio) * (1.0 - denoise_mask_audio)\n        audio_timesteps = denoise_mask_audio * audio_timesteps\n\n    vx, ax = dit(\n        video_latents=video_latents,\n        video_positions=video_positions,\n        video_context=video_context,\n        video_timesteps=video_timesteps,\n        audio_latents=audio_latents,\n        audio_positions=audio_positions,\n        audio_context=audio_context,\n        audio_timesteps=audio_timesteps,\n        sigma=timestep,\n        use_gradient_checkpointing=use_gradient_checkpointing,\n        use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,\n    )\n\n    vx = vx[:, :seq_len_video, ...]\n    # unpatchify\n    vx = video_patchifier.unpatchify_video(vx, f, h, w)\n    ax = audio_patchifier.unpatchify_audio(ax, c_a, mel_bins) if ax is not None else None\n    return vx, ax\n"
  },
  {
    "path": "diffsynth/pipelines/mova_audio_video.py",
    "content": "import sys\nimport torch, types\nfrom PIL import Image\nfrom typing import Optional, Union\nfrom einops import rearrange\nimport numpy as np\nfrom PIL import Image\nfrom tqdm import tqdm\nfrom typing import Optional\n\nfrom ..core.device.npu_compatible_device import get_device_type\nfrom ..diffusion import FlowMatchScheduler\nfrom ..core import ModelConfig, gradient_checkpoint_forward\nfrom ..diffusion.base_pipeline import BasePipeline, PipelineUnit\n\nfrom ..models.wan_video_dit import WanModel, sinusoidal_embedding_1d, set_to_torch_norm\nfrom ..models.wan_video_text_encoder import WanTextEncoder, HuggingfaceTokenizer\nfrom ..models.wan_video_vae import WanVideoVAE\nfrom ..models.mova_audio_dit import MovaAudioDit\nfrom ..models.mova_audio_vae import DacVAE\nfrom ..models.mova_dual_tower_bridge import DualTowerConditionalBridge\nfrom ..utils.data.audio import convert_to_mono, resample_waveform\n\n\nclass MovaAudioVideoPipeline(BasePipeline):\n\n    def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):\n        super().__init__(\n            device=device, torch_dtype=torch_dtype,\n            height_division_factor=16, width_division_factor=16, time_division_factor=4, time_division_remainder=1\n        )\n        self.scheduler = FlowMatchScheduler(\"Wan\")\n        self.tokenizer: HuggingfaceTokenizer = None\n        self.text_encoder: WanTextEncoder = None\n        self.video_dit: WanModel = None # high noise model\n        self.video_dit2: WanModel = None # low noise model\n        self.audio_dit: MovaAudioDit = None\n        self.dual_tower_bridge: DualTowerConditionalBridge = None\n        self.video_vae: WanVideoVAE = None\n        self.audio_vae: DacVAE = None\n\n        self.in_iteration_models = (\"video_dit\", \"audio_dit\", \"dual_tower_bridge\")\n        self.in_iteration_models_2 = (\"video_dit2\", \"audio_dit\", \"dual_tower_bridge\")\n\n        self.units = [\n            MovaAudioVideoUnit_ShapeChecker(),\n       
     MovaAudioVideoUnit_NoiseInitializer(),\n            MovaAudioVideoUnit_InputVideoEmbedder(),\n            MovaAudioVideoUnit_InputAudioEmbedder(),\n            MovaAudioVideoUnit_PromptEmbedder(),\n            MovaAudioVideoUnit_ImageEmbedderVAE(),\n            MovaAudioVideoUnit_UnifiedSequenceParallel(),\n        ]\n        self.model_fn = model_fn_mova_audio_video\n\n    def enable_usp(self):\n        from ..utils.xfuser import get_sequence_parallel_world_size, usp_attn_forward\n        for block in self.video_dit.blocks + self.audio_dit.blocks + self.video_dit2.blocks:\n            block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)\n        self.sp_size = get_sequence_parallel_world_size()\n        self.use_unified_sequence_parallel = True\n\n    @staticmethod\n    def from_pretrained(\n        torch_dtype: torch.dtype = torch.bfloat16,\n        device: Union[str, torch.device] = get_device_type(),\n        model_configs: list[ModelConfig] = [],\n        tokenizer_config: ModelConfig = ModelConfig(model_id=\"openmoss/MOVA-720p\", origin_file_pattern=\"tokenizer/\"),\n        use_usp: bool = False,\n        vram_limit: float = None,\n    ):\n        if use_usp:\n            from ..utils.xfuser import initialize_usp\n            initialize_usp(device)\n            import torch.distributed as dist\n            from ..core.device.npu_compatible_device import get_device_name\n            if dist.is_available() and dist.is_initialized():\n                device = get_device_name()\n        # Initialize pipeline\n        pipe = MovaAudioVideoPipeline(device=device, torch_dtype=torch_dtype)\n        model_pool = pipe.download_and_load_models(model_configs, vram_limit)\n\n        # Fetch models\n        pipe.text_encoder = model_pool.fetch_model(\"wan_video_text_encoder\")\n        dit = model_pool.fetch_model(\"wan_video_dit\", index=2)\n        if isinstance(dit, list):\n            pipe.video_dit, pipe.video_dit2 = dit\n        else:\n 
           pipe.video_dit = dit\n        pipe.audio_dit = model_pool.fetch_model(\"mova_audio_dit\")\n        pipe.dual_tower_bridge = model_pool.fetch_model(\"mova_dual_tower_bridge\")\n        pipe.video_vae = model_pool.fetch_model(\"wan_video_vae\")\n        pipe.audio_vae = model_pool.fetch_model(\"mova_audio_vae\")\n        set_to_torch_norm([pipe.video_dit, pipe.audio_dit, pipe.dual_tower_bridge] + ([pipe.video_dit2] if pipe.video_dit2 is not None else []))\n\n        # Size division factor\n        if pipe.video_vae is not None:\n            pipe.height_division_factor = pipe.video_vae.upsampling_factor * 2\n            pipe.width_division_factor = pipe.video_vae.upsampling_factor * 2\n\n        # Initialize tokenizer and processor\n        if tokenizer_config is not None:\n            tokenizer_config.download_if_necessary()\n            pipe.tokenizer = HuggingfaceTokenizer(name=tokenizer_config.path, seq_len=512, clean='whitespace')\n\n        # Unified Sequence Parallel\n        if use_usp: pipe.enable_usp()\n\n        # VRAM Management\n        pipe.vram_management_enabled = pipe.check_vram_management_state()\n        return pipe\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        # Prompt\n        prompt: str,\n        negative_prompt: Optional[str] = \"\",\n        # Image-to-video\n        input_image: Optional[Image.Image] = None,\n        # First-last-frame-to-video\n        end_image: Optional[Image.Image] = None,\n        # Video-to-video\n        denoising_strength: Optional[float] = 1.0,\n        # Randomness\n        seed: Optional[int] = None,\n        rand_device: Optional[str] = \"cpu\",\n        # Shape\n        height: Optional[int] = 352,\n        width: Optional[int] = 640,\n        num_frames: Optional[int] = 81,\n        frame_rate: Optional[int] = 24,\n        # Classifier-free guidance\n        cfg_scale: Optional[float] = 5.0,\n        # Boundary\n        switch_DiT_boundary: Optional[float] = 0.9,\n        # 
Scheduler\n        num_inference_steps: Optional[int] = 50,\n        sigma_shift: Optional[float] = 5.0,\n        # VAE tiling\n        tiled: Optional[bool] = True,\n        tile_size: Optional[tuple[int, int]] = (30, 52),\n        tile_stride: Optional[tuple[int, int]] = (15, 26),\n        # progress_bar\n        progress_bar_cmd=tqdm,\n    ):\n        # Scheduler\n        self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)\n\n        # Inputs\n        inputs_posi = {\n            \"prompt\": prompt,\n        }\n        inputs_nega = {\n            \"negative_prompt\": negative_prompt,\n        }\n        inputs_shared = {\n            \"input_image\": input_image,\n            \"end_image\": end_image,\n            \"denoising_strength\": denoising_strength,\n            \"seed\": seed, \"rand_device\": rand_device,\n            \"height\": height, \"width\": width, \"num_frames\": num_frames, \"frame_rate\": frame_rate,\n            \"cfg_scale\": cfg_scale,\n            \"sigma_shift\": sigma_shift,\n            \"tiled\": tiled, \"tile_size\": tile_size, \"tile_stride\": tile_stride,\n        }\n        for unit in self.units:\n            inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)\n\n        # Denoise\n        self.load_models_to_device(self.in_iteration_models)\n        models = {name: getattr(self, name) for name in self.in_iteration_models}\n        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):\n            # Switch DiT if necessary\n            if timestep.item() < switch_DiT_boundary * 1000 and self.video_dit2 is not None and not models[\"video_dit\"] is self.video_dit2:\n                self.load_models_to_device(self.in_iteration_models_2)\n                models[\"video_dit\"] = self.video_dit2\n            # Timestep\n            timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, 
device=self.device)\n            noise_pred_video, noise_pred_audio = self.cfg_guided_model_fn(\n                self.model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega,\n                **models, timestep=timestep, progress_id=progress_id\n            )\n            # Scheduler\n            inputs_shared[\"video_latents\"] = self.step(self.scheduler, inputs_shared[\"video_latents\"], progress_id=progress_id, noise_pred=noise_pred_video, **inputs_shared)\n            inputs_shared[\"audio_latents\"] = self.step(self.scheduler, inputs_shared[\"audio_latents\"], progress_id=progress_id, noise_pred=noise_pred_audio, **inputs_shared)\n\n        # Decode\n        self.load_models_to_device(['video_vae'])\n        video = self.video_vae.decode(inputs_shared[\"video_latents\"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)\n        video = self.vae_output_to_video(video)\n        self.load_models_to_device([\"audio_vae\"])\n        audio = self.audio_vae.decode(inputs_shared[\"audio_latents\"])\n        audio = self.output_audio_format_check(audio)\n        self.load_models_to_device([])\n        return video, audio\n\n\nclass MovaAudioVideoUnit_ShapeChecker(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            input_params=(\"height\", \"width\", \"num_frames\"),\n            output_params=(\"height\", \"width\", \"num_frames\"),\n        )\n\n    def process(self, pipe: MovaAudioVideoPipeline, height, width, num_frames):\n        height, width, num_frames = pipe.check_resize_height_width(height, width, num_frames)\n        return {\"height\": height, \"width\": width, \"num_frames\": num_frames}\n\n\nclass MovaAudioVideoUnit_NoiseInitializer(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            input_params=(\"height\", \"width\", \"num_frames\", \"seed\", \"rand_device\", \"frame_rate\"),\n            output_params=(\"video_noise\", \"audio_noise\")\n        )\n\n    def 
process(self, pipe: MovaAudioVideoPipeline, height, width, num_frames, seed, rand_device, frame_rate):\n        length = (num_frames - 1) // 4 + 1\n        video_shape = (1, pipe.video_vae.model.z_dim, length, height // pipe.video_vae.upsampling_factor, width // pipe.video_vae.upsampling_factor)\n        video_noise = pipe.generate_noise(video_shape, seed=seed, rand_device=rand_device)\n\n        audio_num_samples = (int(pipe.audio_vae.sample_rate * num_frames / frame_rate) - 1) // int(pipe.audio_vae.hop_length) + 1\n        audio_shape = (1, pipe.audio_vae.latent_dim, audio_num_samples)\n        audio_noise = pipe.generate_noise(audio_shape, seed=seed, rand_device=rand_device)\n        return {\"video_noise\": video_noise, \"audio_noise\": audio_noise}\n\n\nclass MovaAudioVideoUnit_InputVideoEmbedder(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            input_params=(\"input_video\", \"video_noise\", \"tiled\", \"tile_size\", \"tile_stride\"),\n            output_params=(\"video_latents\", \"input_latents\"),\n            onload_model_names=(\"video_vae\",)\n        )\n\n    def process(self, pipe: MovaAudioVideoPipeline, input_video, video_noise, tiled, tile_size, tile_stride):\n        if input_video is None or not pipe.scheduler.training:\n            return {\"video_latents\": video_noise}\n        else:\n            pipe.load_models_to_device(self.onload_model_names)\n            input_video = pipe.preprocess_video(input_video)\n            input_latents = pipe.video_vae.encode(input_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)\n            return {\"input_latents\": input_latents}\n\n\nclass MovaAudioVideoUnit_InputAudioEmbedder(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            input_params=(\"input_audio\", \"audio_noise\"),\n            output_params=(\"audio_latents\", \"audio_input_latents\"),\n            
onload_model_names=(\"audio_vae\",)\n        )\n\n    def process(self, pipe: MovaAudioVideoPipeline, input_audio, audio_noise):\n        if input_audio is None or not pipe.scheduler.training:\n            return {\"audio_latents\": audio_noise}\n        else:\n            pipe.load_models_to_device(self.onload_model_names)\n            input_audio, sample_rate = input_audio\n            input_audio = convert_to_mono(input_audio)\n            input_audio = resample_waveform(input_audio, sample_rate, pipe.audio_vae.sample_rate)\n            input_audio = pipe.audio_vae.preprocess(input_audio.unsqueeze(0), pipe.audio_vae.sample_rate)\n            z, _, _, _, _ = pipe.audio_vae.encode(input_audio)\n            return {\"audio_input_latents\": z.mode()}\n\n\nclass MovaAudioVideoUnit_PromptEmbedder(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            seperate_cfg=True,\n            input_params_posi={\"prompt\": \"prompt\"},\n            input_params_nega={\"prompt\": \"negative_prompt\"},\n            output_params=(\"context\",),\n            onload_model_names=(\"text_encoder\",)\n        )\n\n    def encode_prompt(self, pipe: MovaAudioVideoPipeline, prompt):\n        ids, mask = pipe.tokenizer(\n            prompt,\n            padding=\"max_length\",\n            max_length=512,\n            truncation=True,\n            add_special_tokens=True,\n            return_mask=True,\n            return_tensors=\"pt\",\n        )\n        ids = ids.to(pipe.device)\n        mask = mask.to(pipe.device)\n        seq_lens = mask.gt(0).sum(dim=1).long()\n        prompt_emb = pipe.text_encoder(ids, mask)\n        for i, v in enumerate(seq_lens):\n            prompt_emb[:, v:] = 0\n        return prompt_emb\n\n    def process(self, pipe: MovaAudioVideoPipeline, prompt) -> dict:\n        pipe.load_models_to_device(self.onload_model_names)\n        prompt_emb = self.encode_prompt(pipe, prompt)\n        return {\"context\": prompt_emb}\n\n\nclass 
MovaAudioVideoUnit_ImageEmbedderVAE(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            input_params=(\"input_image\", \"end_image\", \"num_frames\", \"height\", \"width\", \"tiled\", \"tile_size\", \"tile_stride\"),\n            output_params=(\"y\",),\n            onload_model_names=(\"video_vae\",)\n        )\n\n    def process(self, pipe: MovaAudioVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride):\n        if input_image is None or not pipe.video_dit.require_vae_embedding:\n            return {}\n        pipe.load_models_to_device(self.onload_model_names)\n\n        image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device)\n        msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device)\n        msk[:, 1:] = 0\n        if end_image is not None:\n            end_image = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device)\n            vae_input = torch.concat([image.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image.device), end_image.transpose(0,1)],dim=1)\n            msk[:, -1:] = 1\n        else:\n            vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)\n\n        msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)\n        msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)\n        msk = msk.transpose(1, 2)[0]\n\n        y = pipe.video_vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]\n        y = y.to(dtype=pipe.torch_dtype, device=pipe.device)\n        y = torch.concat([msk, y])\n        y = y.unsqueeze(0)\n        y = y.to(dtype=pipe.torch_dtype, device=pipe.device)\n        return {\"y\": y}\n\n\nclass MovaAudioVideoUnit_UnifiedSequenceParallel(PipelineUnit):\n    def 
__init__(self):\n        super().__init__(input_params=(), output_params=(\"use_unified_sequence_parallel\",))\n\n    def process(self, pipe: MovaAudioVideoPipeline):\n        if hasattr(pipe, \"use_unified_sequence_parallel\") and pipe.use_unified_sequence_parallel:\n            return {\"use_unified_sequence_parallel\": True}\n        return {\"use_unified_sequence_parallel\": False}\n\n\ndef model_fn_mova_audio_video(\n    video_dit: WanModel,\n    audio_dit: MovaAudioDit,\n    dual_tower_bridge: DualTowerConditionalBridge,\n    video_latents: torch.Tensor = None,\n    audio_latents: torch.Tensor = None,\n    timestep: torch.Tensor = None,\n    context: torch.Tensor = None,\n    y: Optional[torch.Tensor] = None,\n    frame_rate: Optional[int] = 24,\n    use_unified_sequence_parallel: bool = False,\n    use_gradient_checkpointing: bool = False,\n    use_gradient_checkpointing_offload: bool = False,\n    **kwargs,\n):\n    video_x, audio_x = video_latents, audio_latents\n    # First-Last Frame\n    if y is not None:\n        video_x = torch.cat([video_x, y], dim=1)\n\n    # Timestep\n    video_t = video_dit.time_embedding(sinusoidal_embedding_1d(video_dit.freq_dim, timestep))\n    video_t_mod = video_dit.time_projection(video_t).unflatten(1, (6, video_dit.dim))\n    audio_t = audio_dit.time_embedding(sinusoidal_embedding_1d(audio_dit.freq_dim, timestep))\n    audio_t_mod = audio_dit.time_projection(audio_t).unflatten(1, (6, audio_dit.dim))\n\n    # Context\n    video_context = video_dit.text_embedding(context)\n    audio_context = audio_dit.text_embedding(context)\n\n    # Patchify\n    video_x = video_dit.patch_embedding(video_x)\n    f_v, h, w = video_x.shape[2:]\n    video_x = rearrange(video_x, 'b c f h w -> b (f h w) c').contiguous()\n    seq_len_video = video_x.shape[1]\n\n    audio_x = audio_dit.patch_embedding(audio_x)\n    f_a = audio_x.shape[2]\n    audio_x = rearrange(audio_x, 'b c f -> b f c').contiguous()\n    seq_len_audio = audio_x.shape[1]\n\n    # 
Freqs\n    video_freqs = torch.cat([\n        video_dit.freqs[0][:f_v].view(f_v, 1, 1, -1).expand(f_v, h, w, -1),\n        video_dit.freqs[1][:h].view(1, h, 1, -1).expand(f_v, h, w, -1),\n        video_dit.freqs[2][:w].view(1, 1, w, -1).expand(f_v, h, w, -1)\n    ], dim=-1).reshape(f_v * h * w, 1, -1).to(video_x.device)\n    audio_freqs = torch.cat([\n        audio_dit.freqs[0][:f_a].view(f_a, -1).expand(f_a, -1),\n        audio_dit.freqs[1][:f_a].view(f_a, -1).expand(f_a, -1),\n        audio_dit.freqs[2][:f_a].view(f_a, -1).expand(f_a, -1),\n    ], dim=-1).reshape(f_a, 1, -1).to(audio_x.device)\n\n    video_rope, audio_rope = dual_tower_bridge.build_aligned_freqs(\n        video_fps=frame_rate,\n        grid_size=(f_v, h, w),\n        audio_steps=audio_x.shape[1],\n        device=video_x.device,\n        dtype=video_x.dtype,\n    )\n    # usp func\n    if use_unified_sequence_parallel:\n        from ..utils.xfuser import get_current_chunk, gather_all_chunks\n    else:\n        get_current_chunk = lambda x, dim=1: x\n        gather_all_chunks = lambda x, seq_len, dim=1: x\n    # Forward blocks\n    for block_id in range(len(audio_dit.blocks)):\n        if dual_tower_bridge.should_interact(block_id, \"a2v\"):\n            video_x, audio_x = dual_tower_bridge(\n                block_id,\n                video_x,\n                audio_x,\n                x_freqs=video_rope,\n                y_freqs=audio_rope,\n                condition_scale=1.0,\n                video_grid_size=(f_v, h, w),\n                use_gradient_checkpointing=use_gradient_checkpointing,\n                use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,\n            )\n        video_x = get_current_chunk(video_x, dim=1)\n        video_x = gradient_checkpoint_forward(\n            video_dit.blocks[block_id],\n            use_gradient_checkpointing,\n            use_gradient_checkpointing_offload,\n            video_x, video_context, video_t_mod, video_freqs\n        )\n   
     video_x = gather_all_chunks(video_x, seq_len=seq_len_video, dim=1)\n        audio_x = get_current_chunk(audio_x, dim=1)\n        audio_x = gradient_checkpoint_forward(\n            audio_dit.blocks[block_id],\n            use_gradient_checkpointing,\n            use_gradient_checkpointing_offload,\n            audio_x, audio_context, audio_t_mod, audio_freqs\n        )\n        audio_x = gather_all_chunks(audio_x, seq_len=seq_len_audio, dim=1)\n\n    video_x = get_current_chunk(video_x, dim=1)\n    for block_id in range(len(audio_dit.blocks), len(video_dit.blocks)):\n        video_x = gradient_checkpoint_forward(\n            video_dit.blocks[block_id],\n            use_gradient_checkpointing,\n            use_gradient_checkpointing_offload,\n            video_x, video_context, video_t_mod, video_freqs\n        )\n    video_x = gather_all_chunks(video_x, seq_len=seq_len_video, dim=1)\n\n    # Head\n    video_x = video_dit.head(video_x, video_t)\n    video_x = video_dit.unpatchify(video_x, (f_v, h, w))\n\n    audio_x = audio_dit.head(audio_x, audio_t)\n    audio_x = audio_dit.unpatchify(audio_x, (f_a,))\n    return video_x, audio_x\n"
  },
  {
    "path": "diffsynth/pipelines/qwen_image.py",
    "content": "import torch, math\nfrom PIL import Image\nfrom typing import Union\nfrom tqdm import tqdm\nfrom einops import rearrange\nimport numpy as np\nfrom math import prod\n\nfrom ..core.device.npu_compatible_device import get_device_type\nfrom ..diffusion import FlowMatchScheduler\nfrom ..core import ModelConfig, gradient_checkpoint_forward\nfrom ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput\nfrom ..utils.lora.merge import merge_lora\n\nfrom ..models.qwen_image_dit import QwenImageDiT\nfrom ..models.qwen_image_text_encoder import QwenImageTextEncoder\nfrom ..models.qwen_image_vae import QwenImageVAE\nfrom ..models.qwen_image_controlnet import QwenImageBlockWiseControlNet\nfrom ..models.siglip2_image_encoder import Siglip2ImageEncoder\nfrom ..models.dinov3_image_encoder import DINOv3ImageEncoder\nfrom ..models.qwen_image_image2lora import QwenImageImage2LoRAModel\n\n\nclass QwenImagePipeline(BasePipeline):\n\n    def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):\n        super().__init__(\n            device=device, torch_dtype=torch_dtype,\n            height_division_factor=16, width_division_factor=16,\n        )\n        from transformers import Qwen2Tokenizer, Qwen2VLProcessor\n        \n        self.scheduler = FlowMatchScheduler(\"Qwen-Image\")\n        self.text_encoder: QwenImageTextEncoder = None\n        self.dit: QwenImageDiT = None\n        self.vae: QwenImageVAE = None\n        self.blockwise_controlnet: QwenImageBlockwiseMultiControlNet = None\n        self.tokenizer: Qwen2Tokenizer = None\n        self.siglip2_image_encoder: Siglip2ImageEncoder = None\n        self.dinov3_image_encoder: DINOv3ImageEncoder = None\n        self.image2lora_style: QwenImageImage2LoRAModel = None\n        self.image2lora_coarse: QwenImageImage2LoRAModel = None\n        self.image2lora_fine: QwenImageImage2LoRAModel = None\n        self.processor: Qwen2VLProcessor = None\n        self.in_iteration_models = 
(\"dit\", \"blockwise_controlnet\")\n        self.units = [\n            QwenImageUnit_ShapeChecker(),\n            QwenImageUnit_NoiseInitializer(),\n            QwenImageUnit_InputImageEmbedder(),\n            QwenImageUnit_Inpaint(),\n            QwenImageUnit_EditImageEmbedder(),\n            QwenImageUnit_LayerInputImageEmbedder(),\n            QwenImageUnit_ContextImageEmbedder(),\n            QwenImageUnit_PromptEmbedder(),\n            QwenImageUnit_EntityControl(),\n            QwenImageUnit_BlockwiseControlNet(),\n        ]\n        self.model_fn = model_fn_qwen_image\n    \n    \n    @staticmethod\n    def from_pretrained(\n        torch_dtype: torch.dtype = torch.bfloat16,\n        device: Union[str, torch.device] = get_device_type(),\n        model_configs: list[ModelConfig] = [],\n        tokenizer_config: ModelConfig = ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n        processor_config: ModelConfig = None,\n        vram_limit: float = None,\n    ):\n        # Initialize pipeline\n        pipe = QwenImagePipeline(device=device, torch_dtype=torch_dtype)\n        model_pool = pipe.download_and_load_models(model_configs, vram_limit)\n        \n        # Fetch models\n        pipe.text_encoder = model_pool.fetch_model(\"qwen_image_text_encoder\")\n        pipe.dit = model_pool.fetch_model(\"qwen_image_dit\")\n        pipe.vae = model_pool.fetch_model(\"qwen_image_vae\")\n        pipe.blockwise_controlnet = QwenImageBlockwiseMultiControlNet(model_pool.fetch_model(\"qwen_image_blockwise_controlnet\", index=\"all\"))\n        if tokenizer_config is not None:\n            tokenizer_config.download_if_necessary()\n            from transformers import Qwen2Tokenizer\n            pipe.tokenizer = Qwen2Tokenizer.from_pretrained(tokenizer_config.path)\n        if processor_config is not None:\n            processor_config.download_if_necessary()\n            from transformers import Qwen2VLProcessor\n            pipe.processor 
= Qwen2VLProcessor.from_pretrained(processor_config.path)\n        pipe.siglip2_image_encoder = model_pool.fetch_model(\"siglip2_image_encoder\")\n        pipe.dinov3_image_encoder = model_pool.fetch_model(\"dinov3_image_encoder\")\n        pipe.image2lora_style = model_pool.fetch_model(\"qwen_image_image2lora_style\")\n        pipe.image2lora_coarse = model_pool.fetch_model(\"qwen_image_image2lora_coarse\")\n        pipe.image2lora_fine = model_pool.fetch_model(\"qwen_image_image2lora_fine\")\n        \n        # VRAM Management\n        pipe.vram_management_enabled = pipe.check_vram_management_state()\n        return pipe\n    \n    \n    @torch.no_grad()\n    def __call__(\n        self,\n        # Prompt\n        prompt: str,\n        negative_prompt: str = \"\",\n        cfg_scale: float = 4.0,\n        # Image\n        input_image: Image.Image = None,\n        denoising_strength: float = 1.0,\n        # Inpaint\n        inpaint_mask: Image.Image = None,\n        inpaint_blur_size: int = None,\n        inpaint_blur_sigma: float = None,\n        # Shape\n        height: int = 1328,\n        width: int = 1328,\n        # Randomness\n        seed: int = None,\n        rand_device: str = \"cpu\",\n        # Steps\n        num_inference_steps: int = 30,\n        exponential_shift_mu: float = None,\n        # Blockwise ControlNet\n        blockwise_controlnet_inputs: list[ControlNetInput] = None,\n        # EliGen\n        eligen_entity_prompts: list[str] = None,\n        eligen_entity_masks: list[Image.Image] = None,\n        eligen_enable_on_negative: bool = False,\n        # Qwen-Image-Edit\n        edit_image: Image.Image = None,\n        edit_image_auto_resize: bool = True,\n        edit_rope_interpolation: bool = False,\n        # Qwen-Image-Edit-2511\n        zero_cond_t: bool = False,\n        # Qwen-Image-Layered\n        layer_input_image: Image.Image = None,\n        layer_num: int = None,\n        # In-context control\n        context_image: Image.Image 
= None,\n        # Tile\n        tiled: bool = False,\n        tile_size: int = 128,\n        tile_stride: int = 64,\n        # Progress bar\n        progress_bar_cmd = tqdm,\n    ):\n        # Scheduler\n        self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=(height // 16) * (width // 16), exponential_shift_mu=exponential_shift_mu)\n        \n        # Parameters\n        inputs_posi = {\n            \"prompt\": prompt,\n        }\n        inputs_nega = {\n            \"negative_prompt\": negative_prompt,\n        }\n        inputs_shared = {\n            \"cfg_scale\": cfg_scale,\n            \"input_image\": input_image, \"denoising_strength\": denoising_strength,\n            \"inpaint_mask\": inpaint_mask, \"inpaint_blur_size\": inpaint_blur_size, \"inpaint_blur_sigma\": inpaint_blur_sigma,\n            \"height\": height, \"width\": width,\n            \"seed\": seed, \"rand_device\": rand_device,\n            \"num_inference_steps\": num_inference_steps,\n            \"blockwise_controlnet_inputs\": blockwise_controlnet_inputs,\n            \"tiled\": tiled, \"tile_size\": tile_size, \"tile_stride\": tile_stride,\n            \"eligen_entity_prompts\": eligen_entity_prompts, \"eligen_entity_masks\": eligen_entity_masks, \"eligen_enable_on_negative\": eligen_enable_on_negative,\n            \"edit_image\": edit_image, \"edit_image_auto_resize\": edit_image_auto_resize, \"edit_rope_interpolation\": edit_rope_interpolation, \n            \"context_image\": context_image,\n            \"zero_cond_t\": zero_cond_t,\n            \"layer_input_image\": layer_input_image,\n            \"layer_num\": layer_num,\n        }\n        for unit in self.units:\n            inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)\n\n        # Denoise\n        self.load_models_to_device(self.in_iteration_models)\n        models = {name: getattr(self, name) for 
name in self.in_iteration_models}\n        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):\n            timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)\n            noise_pred = self.cfg_guided_model_fn(\n                self.model_fn, cfg_scale,\n                inputs_shared, inputs_posi, inputs_nega,\n                **models, timestep=timestep, progress_id=progress_id\n            )\n            inputs_shared[\"latents\"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)\n        \n        # Decode\n        self.load_models_to_device(['vae'])\n        image = self.vae.decode(inputs_shared[\"latents\"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)\n        if layer_num is None:\n            image = self.vae_output_to_image(image)\n        else:\n            image = [self.vae_output_to_image(i, pattern=\"C H W\") for i in image]\n        self.load_models_to_device([])\n\n        return image\n\n\nclass QwenImageBlockwiseMultiControlNet(torch.nn.Module):\n    def __init__(self, models: list[QwenImageBlockWiseControlNet]):\n        super().__init__()\n        if not isinstance(models, list):\n            models = [models]\n        self.models = torch.nn.ModuleList(models)\n        for model in models:\n            if hasattr(model, \"vram_management_enabled\") and getattr(model, \"vram_management_enabled\"):\n                self.vram_management_enabled = True\n\n    def preprocess(self, controlnet_inputs: list[ControlNetInput], conditionings: list[torch.Tensor], **kwargs):\n        processed_conditionings = []\n        for controlnet_input, conditioning in zip(controlnet_inputs, conditionings):\n            conditioning = rearrange(conditioning, \"B C (H P) (W Q) -> B (H W) (C P Q)\", P=2, Q=2)\n            model_output = self.models[controlnet_input.controlnet_id].process_controlnet_conditioning(conditioning)\n           
 processed_conditionings.append(model_output)\n        return processed_conditionings\n\n    def blockwise_forward(self, image, conditionings: list[torch.Tensor], controlnet_inputs: list[ControlNetInput], progress_id, num_inference_steps, block_id, **kwargs):\n        res = 0\n        for controlnet_input, conditioning in zip(controlnet_inputs, conditionings):\n            progress = (num_inference_steps - 1 - progress_id) / max(num_inference_steps - 1, 1)\n            if progress > controlnet_input.start + (1e-4) or progress < controlnet_input.end - (1e-4):\n                continue\n            model_output = self.models[controlnet_input.controlnet_id].blockwise_forward(image, conditioning, block_id)\n            res = res + model_output * controlnet_input.scale\n        return res\n\n\nclass QwenImageUnit_ShapeChecker(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            input_params=(\"height\", \"width\"),\n            output_params=(\"height\", \"width\"),\n        )\n\n    def process(self, pipe: QwenImagePipeline, height, width):\n        height, width = pipe.check_resize_height_width(height, width)\n        return {\"height\": height, \"width\": width}\n\n\n\nclass QwenImageUnit_NoiseInitializer(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            input_params=(\"height\", \"width\", \"seed\", \"rand_device\", \"layer_num\"),\n            output_params=(\"noise\",),\n        )\n\n    def process(self, pipe: QwenImagePipeline, height, width, seed, rand_device, layer_num):\n        if layer_num is None:\n            noise = pipe.generate_noise((1, 16, height//8, width//8), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)\n        else:\n            noise = pipe.generate_noise((layer_num + 1, 16, height//8, width//8), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)\n        return {\"noise\": noise}\n\n\n\nclass QwenImageUnit_InputImageEmbedder(PipelineUnit):\n   
 def __init__(self):\n        super().__init__(\n            input_params=(\"input_image\", \"noise\", \"tiled\", \"tile_size\", \"tile_stride\"),\n            output_params=(\"latents\", \"input_latents\"),\n            onload_model_names=(\"vae\",)\n        )\n\n    def process(self, pipe: QwenImagePipeline, input_image, noise, tiled, tile_size, tile_stride):\n        if input_image is None:\n            return {\"latents\": noise, \"input_latents\": None}\n        pipe.load_models_to_device(['vae'])\n        if isinstance(input_image, list):\n            input_latents = []\n            for image in input_image:\n                image = pipe.preprocess_image(image).to(device=pipe.device, dtype=pipe.torch_dtype)\n                input_latents.append(pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride))\n            input_latents = torch.concat(input_latents, dim=0)\n        else:\n            image = pipe.preprocess_image(input_image).to(device=pipe.device, dtype=pipe.torch_dtype)\n            input_latents = pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)\n        if pipe.scheduler.training:\n            return {\"latents\": noise, \"input_latents\": input_latents}\n        else:\n            latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])\n            return {\"latents\": latents, \"input_latents\": input_latents}\n\n\nclass QwenImageUnit_LayerInputImageEmbedder(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            input_params=(\"layer_input_image\", \"tiled\", \"tile_size\", \"tile_stride\"),\n            output_params=(\"layer_input_latents\",),\n            onload_model_names=(\"vae\",)\n        )\n\n    def process(self, pipe: QwenImagePipeline, layer_input_image, tiled, tile_size, tile_stride):\n        if layer_input_image is None:\n            return {}\n        pipe.load_models_to_device(['vae'])\n        image = 
pipe.preprocess_image(layer_input_image).to(device=pipe.device, dtype=pipe.torch_dtype)\n        latents = pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)\n        return {\"layer_input_latents\": latents}\n\n\nclass QwenImageUnit_Inpaint(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            input_params=(\"inpaint_mask\", \"height\", \"width\", \"inpaint_blur_size\", \"inpaint_blur_sigma\"),\n            output_params=(\"inpaint_mask\",),\n        )\n\n    def process(self, pipe: QwenImagePipeline, inpaint_mask, height, width, inpaint_blur_size, inpaint_blur_sigma):\n        if inpaint_mask is None:\n            return {}\n        inpaint_mask = pipe.preprocess_image(inpaint_mask.convert(\"RGB\").resize((width // 8, height // 8)), min_value=0, max_value=1)\n        inpaint_mask = inpaint_mask.mean(dim=1, keepdim=True)\n        if inpaint_blur_size is not None and inpaint_blur_sigma is not None:\n            from torchvision.transforms import GaussianBlur\n            blur = GaussianBlur(kernel_size=inpaint_blur_size * 2 + 1, sigma=inpaint_blur_sigma)\n            inpaint_mask = blur(inpaint_mask)\n        return {\"inpaint_mask\": inpaint_mask}\n\n\nclass QwenImageUnit_PromptEmbedder(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            seperate_cfg=True,\n            input_params_posi={\"prompt\": \"prompt\"},\n            input_params_nega={\"prompt\": \"negative_prompt\"},\n            input_params=(\"edit_image\",),\n            output_params=(\"prompt_emb\", \"prompt_emb_mask\"),\n            onload_model_names=(\"text_encoder\",)\n        )\n        \n    def extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):\n        bool_mask = mask.bool()\n        valid_lengths = bool_mask.sum(dim=1)\n        selected = hidden_states[bool_mask]\n        split_result = torch.split(selected, valid_lengths.tolist(), dim=0)\n        return split_result\n    \n    
def calculate_dimensions(self, target_area, ratio):\n        width = math.sqrt(target_area * ratio)\n        height = width / ratio\n        width = round(width / 32) * 32\n        height = round(height / 32) * 32\n        return width, height\n    \n    def resize_image(self, image, target_area=384*384):\n        width, height = self.calculate_dimensions(target_area, image.size[0] / image.size[1])\n        return image.resize((width, height))\n    \n    def encode_prompt(self, pipe: QwenImagePipeline, prompt):\n        template = \"<|im_start|>system\\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\\n<|im_start|>user\\n{}<|im_end|>\\n<|im_start|>assistant\\n\"\n        drop_idx = 34\n        txt = [template.format(e) for e in prompt]\n        model_inputs = pipe.tokenizer(txt, max_length=4096+drop_idx, padding=True, truncation=True, return_tensors=\"pt\").to(pipe.device)\n        if model_inputs.input_ids.shape[1] >= 1024:\n            print(f\"Warning!!! QwenImage model was trained on prompts up to 512 tokens. Current prompt requires {model_inputs['input_ids'].shape[1] - drop_idx} tokens, which may lead to unpredictable behavior.\")\n        hidden_states = pipe.text_encoder(input_ids=model_inputs.input_ids, attention_mask=model_inputs.attention_mask, output_hidden_states=True,)[-1]\n        split_hidden_states = self.extract_masked_hidden(hidden_states, model_inputs.attention_mask)\n        split_hidden_states = [e[drop_idx:] for e in split_hidden_states]\n        return split_hidden_states\n        \n    def encode_prompt_edit(self, pipe: QwenImagePipeline, prompt, edit_image):\n        template =  \"<|im_start|>system\\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. 
Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\\n<|im_start|>user\\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\\n<|im_start|>assistant\\n\"\n        drop_idx = 64\n        txt = [template.format(e) for e in prompt]\n        model_inputs = pipe.processor(text=txt, images=edit_image, padding=True, return_tensors=\"pt\").to(pipe.device)\n        hidden_states = pipe.text_encoder(input_ids=model_inputs.input_ids, attention_mask=model_inputs.attention_mask, pixel_values=model_inputs.pixel_values, image_grid_thw=model_inputs.image_grid_thw, output_hidden_states=True,)[-1]\n        split_hidden_states = self.extract_masked_hidden(hidden_states, model_inputs.attention_mask)\n        split_hidden_states = [e[drop_idx:] for e in split_hidden_states]\n        return split_hidden_states\n    \n    def encode_prompt_edit_multi(self, pipe: QwenImagePipeline, prompt, edit_image):\n        template =  \"<|im_start|>system\\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. 
Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\\n<|im_start|>user\\n{}<|im_end|>\\n<|im_start|>assistant\\n\"\n        drop_idx = 64\n        img_prompt_template = \"Picture {}: <|vision_start|><|image_pad|><|vision_end|>\"\n        base_img_prompt = \"\".join([img_prompt_template.format(i + 1) for i in range(len(edit_image))])\n        txt = [template.format(base_img_prompt + e) for e in prompt]\n        edit_image = [self.resize_image(image) for image in edit_image]\n        model_inputs = pipe.processor(text=txt, images=edit_image, padding=True, return_tensors=\"pt\").to(pipe.device)\n        hidden_states = pipe.text_encoder(input_ids=model_inputs.input_ids, attention_mask=model_inputs.attention_mask, pixel_values=model_inputs.pixel_values, image_grid_thw=model_inputs.image_grid_thw, output_hidden_states=True,)[-1]\n        split_hidden_states = self.extract_masked_hidden(hidden_states, model_inputs.attention_mask)\n        split_hidden_states = [e[drop_idx:] for e in split_hidden_states]\n        return split_hidden_states\n\n    def process(self, pipe: QwenImagePipeline, prompt, edit_image=None) -> dict:\n        pipe.load_models_to_device(self.onload_model_names)\n        if pipe.text_encoder is not None:\n            prompt = [prompt]\n            if edit_image is None:\n                split_hidden_states = self.encode_prompt(pipe, prompt)\n            elif isinstance(edit_image, Image.Image):\n                split_hidden_states = self.encode_prompt_edit(pipe, prompt, edit_image)\n            else:\n                split_hidden_states = self.encode_prompt_edit_multi(pipe, prompt, edit_image)\n            attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]\n            max_seq_len = max([e.size(0) for e in split_hidden_states])\n            prompt_embeds = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - 
u.size(0), u.size(1))]) for u in split_hidden_states])\n            encoder_attention_mask = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list])\n            prompt_embeds = prompt_embeds.to(dtype=pipe.torch_dtype, device=pipe.device)\n            return {\"prompt_emb\": prompt_embeds, \"prompt_emb_mask\": encoder_attention_mask}\n        else:\n            return {}\n\n\nclass QwenImageUnit_EntityControl(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            take_over=True,\n            input_params=(\"eligen_entity_prompts\", \"width\", \"height\", \"eligen_enable_on_negative\", \"cfg_scale\"),\n            output_params=(\"entity_prompt_emb\", \"entity_masks\", \"entity_prompt_emb_mask\"),\n            onload_model_names=(\"text_encoder\",)\n        )\n\n    def extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):\n        bool_mask = mask.bool()\n        valid_lengths = bool_mask.sum(dim=1)\n        selected = hidden_states[bool_mask]\n        split_result = torch.split(selected, valid_lengths.tolist(), dim=0)\n        return split_result\n\n    def get_prompt_emb(self, pipe: QwenImagePipeline, prompt) -> dict:\n        if pipe.text_encoder is not None:\n            prompt = [prompt]\n            template = \"<|im_start|>system\\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\\n<|im_start|>user\\n{}<|im_end|>\\n<|im_start|>assistant\\n\"\n            drop_idx = 34\n            txt = [template.format(e) for e in prompt]\n            txt_tokens = pipe.tokenizer(txt, max_length=1024+drop_idx, padding=True, truncation=True, return_tensors=\"pt\").to(pipe.device)\n            hidden_states = pipe.text_encoder(input_ids=txt_tokens.input_ids, attention_mask=txt_tokens.attention_mask, output_hidden_states=True,)[-1]\n            \n            split_hidden_states = 
self.extract_masked_hidden(hidden_states, txt_tokens.attention_mask)\n            split_hidden_states = [e[drop_idx:] for e in split_hidden_states]\n            attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]\n            max_seq_len = max([e.size(0) for e in split_hidden_states])\n            prompt_embeds = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states])\n            encoder_attention_mask = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list])\n            prompt_embeds = prompt_embeds.to(dtype=pipe.torch_dtype, device=pipe.device)\n            return {\"prompt_emb\": prompt_embeds, \"prompt_emb_mask\": encoder_attention_mask}\n        else:\n            return {}\n\n    def preprocess_masks(self, pipe, masks, height, width, dim):\n        out_masks = []\n        for mask in masks:\n            mask = pipe.preprocess_image(mask.resize((width, height), resample=Image.NEAREST)).mean(dim=1, keepdim=True) > 0\n            mask = mask.repeat(1, dim, 1, 1).to(device=pipe.device, dtype=pipe.torch_dtype)\n            out_masks.append(mask)\n        return out_masks\n\n    def prepare_entity_inputs(self, pipe, entity_prompts, entity_masks, width, height):\n        entity_masks = self.preprocess_masks(pipe, entity_masks, height//8, width//8, 1)\n        entity_masks = torch.cat(entity_masks, dim=0).unsqueeze(0) # b, n_mask, c, h, w\n        prompt_embs, prompt_emb_masks = [], []\n        for entity_prompt in entity_prompts:\n            prompt_emb_dict = self.get_prompt_emb(pipe, entity_prompt)\n            prompt_embs.append(prompt_emb_dict['prompt_emb'])\n            prompt_emb_masks.append(prompt_emb_dict['prompt_emb_mask'])\n        return prompt_embs, prompt_emb_masks, entity_masks\n\n    def prepare_eligen(self, pipe, prompt_emb_nega, eligen_entity_prompts, eligen_entity_masks, width, height, 
enable_eligen_on_negative, cfg_scale):\n        entity_prompt_emb_posi, entity_prompt_emb_posi_mask, entity_masks_posi = self.prepare_entity_inputs(pipe, eligen_entity_prompts, eligen_entity_masks, width, height)\n        if enable_eligen_on_negative and cfg_scale != 1.0:\n            entity_prompt_emb_nega = [prompt_emb_nega['prompt_emb']] * len(entity_prompt_emb_posi)\n            entity_prompt_emb_nega_mask = [prompt_emb_nega['prompt_emb_mask']] * len(entity_prompt_emb_posi)\n            entity_masks_nega = entity_masks_posi\n        else:\n            entity_prompt_emb_nega, entity_prompt_emb_nega_mask, entity_masks_nega = None, None, None\n        eligen_kwargs_posi = {\"entity_prompt_emb\": entity_prompt_emb_posi, \"entity_masks\": entity_masks_posi, \"entity_prompt_emb_mask\": entity_prompt_emb_posi_mask}\n        eligen_kwargs_nega = {\"entity_prompt_emb\": entity_prompt_emb_nega, \"entity_masks\": entity_masks_nega, \"entity_prompt_emb_mask\": entity_prompt_emb_nega_mask}\n        return eligen_kwargs_posi, eligen_kwargs_nega\n\n    def process(self, pipe: QwenImagePipeline, inputs_shared, inputs_posi, inputs_nega):\n        eligen_entity_prompts, eligen_entity_masks = inputs_shared.get(\"eligen_entity_prompts\", None), inputs_shared.get(\"eligen_entity_masks\", None)\n        if eligen_entity_prompts is None or eligen_entity_masks is None or len(eligen_entity_prompts) == 0 or len(eligen_entity_masks) == 0:\n            return inputs_shared, inputs_posi, inputs_nega\n        pipe.load_models_to_device(self.onload_model_names)\n        eligen_enable_on_negative = inputs_shared.get(\"eligen_enable_on_negative\", False)\n        eligen_kwargs_posi, eligen_kwargs_nega = self.prepare_eligen(pipe, inputs_nega,\n            eligen_entity_prompts, eligen_entity_masks, inputs_shared[\"width\"], inputs_shared[\"height\"],\n            eligen_enable_on_negative, inputs_shared[\"cfg_scale\"])\n        inputs_posi.update(eligen_kwargs_posi)\n        if 
inputs_shared.get(\"cfg_scale\", 1.0) != 1.0:\n            inputs_nega.update(eligen_kwargs_nega)\n        return inputs_shared, inputs_posi, inputs_nega\n\n\n\nclass QwenImageUnit_BlockwiseControlNet(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            input_params=(\"blockwise_controlnet_inputs\", \"tiled\", \"tile_size\", \"tile_stride\"),\n            output_params=(\"blockwise_controlnet_conditioning\",),\n            onload_model_names=(\"vae\",)\n        )\n\n    def apply_controlnet_mask_on_latents(self, pipe, latents, mask):\n        mask = (pipe.preprocess_image(mask) + 1) / 2\n        mask = mask.mean(dim=1, keepdim=True)\n        mask = 1 - torch.nn.functional.interpolate(mask, size=latents.shape[-2:])\n        latents = torch.concat([latents, mask], dim=1)\n        return latents\n\n    def apply_controlnet_mask_on_image(self, pipe, image, mask):\n        mask = mask.resize(image.size)\n        mask = pipe.preprocess_image(mask).mean(dim=[0, 1]).cpu()\n        image = np.array(image)\n        image[mask > 0] = 0\n        image = Image.fromarray(image)\n        return image\n\n    def process(self, pipe: QwenImagePipeline, blockwise_controlnet_inputs: list[ControlNetInput], tiled, tile_size, tile_stride):\n        if blockwise_controlnet_inputs is None:\n            return {}\n        pipe.load_models_to_device(self.onload_model_names)\n        conditionings = []\n        for controlnet_input in blockwise_controlnet_inputs:\n            image = controlnet_input.image\n            if controlnet_input.inpaint_mask is not None:\n                image = self.apply_controlnet_mask_on_image(pipe, image, controlnet_input.inpaint_mask)\n\n            image = pipe.preprocess_image(image).to(device=pipe.device, dtype=pipe.torch_dtype)\n            image = pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)\n\n            if controlnet_input.inpaint_mask is not None:\n                image = 
self.apply_controlnet_mask_on_latents(pipe, image, controlnet_input.inpaint_mask)\n            conditionings.append(image)\n            \n        return {\"blockwise_controlnet_conditioning\": conditionings}\n\n\nclass QwenImageUnit_EditImageEmbedder(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            input_params=(\"edit_image\", \"tiled\", \"tile_size\", \"tile_stride\", \"edit_image_auto_resize\"),\n            output_params=(\"edit_latents\", \"edit_image\"),\n            onload_model_names=(\"vae\",)\n        )\n\n\n    def calculate_dimensions(self, target_area, ratio):\n        import math\n        width = math.sqrt(target_area * ratio)\n        height = width / ratio\n        width = round(width / 32) * 32\n        height = round(height / 32) * 32\n        return width, height\n\n\n    def edit_image_auto_resize(self, edit_image):\n        calculated_width, calculated_height = self.calculate_dimensions(1024 * 1024, edit_image.size[0] / edit_image.size[1])\n        return edit_image.resize((calculated_width, calculated_height))\n\n\n    def process(self, pipe: QwenImagePipeline, edit_image, tiled, tile_size, tile_stride, edit_image_auto_resize=False):\n        if edit_image is None:\n            return {}\n        pipe.load_models_to_device(self.onload_model_names)\n        if isinstance(edit_image, Image.Image):\n            resized_edit_image = self.edit_image_auto_resize(edit_image) if edit_image_auto_resize else edit_image\n            edit_image = pipe.preprocess_image(resized_edit_image).to(device=pipe.device, dtype=pipe.torch_dtype)\n            edit_latents = pipe.vae.encode(edit_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)\n        else:\n            resized_edit_image, edit_latents = [], []\n            for image in edit_image:\n                if edit_image_auto_resize:\n                    image = self.edit_image_auto_resize(image)\n                resized_edit_image.append(image)\n                
image = pipe.preprocess_image(image).to(device=pipe.device, dtype=pipe.torch_dtype)\n                latents = pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)\n                edit_latents.append(latents)\n        return {\"edit_latents\": edit_latents, \"edit_image\": resized_edit_image}\n\n\nclass QwenImageUnit_Image2LoRAEncode(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            input_params=(\"image2lora_images\",),\n            output_params=(\"image2lora_x\", \"image2lora_residual\", \"image2lora_residual_highres\"),\n            onload_model_names=(\"siglip2_image_encoder\", \"dinov3_image_encoder\", \"text_encoder\"),\n        )\n        from ..core.data.operators import ImageCropAndResize\n        self.processor_lowres = ImageCropAndResize(height=28*8, width=28*8)\n        self.processor_highres = ImageCropAndResize(height=1024, width=1024)\n\n    def extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):\n        bool_mask = mask.bool()\n        valid_lengths = bool_mask.sum(dim=1)\n        selected = hidden_states[bool_mask]\n        split_result = torch.split(selected, valid_lengths.tolist(), dim=0)\n        return split_result\n\n    def encode_prompt_edit(self, pipe: QwenImagePipeline, prompt, edit_image):\n        prompt = [prompt]\n        template =  \"<|im_start|>system\\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. 
Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\\n<|im_start|>user\\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\\n<|im_start|>assistant\\n\"\n        drop_idx = 64\n        txt = [template.format(e) for e in prompt]\n        model_inputs = pipe.processor(text=txt, images=edit_image, padding=True, return_tensors=\"pt\").to(pipe.device)\n        hidden_states = pipe.text_encoder(input_ids=model_inputs.input_ids, attention_mask=model_inputs.attention_mask, pixel_values=model_inputs.pixel_values, image_grid_thw=model_inputs.image_grid_thw, output_hidden_states=True,)[-1]\n        split_hidden_states = self.extract_masked_hidden(hidden_states, model_inputs.attention_mask)\n        split_hidden_states = [e[drop_idx:] for e in split_hidden_states]\n        max_seq_len = max([e.size(0) for e in split_hidden_states])\n        prompt_embeds = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states])\n        prompt_embeds = prompt_embeds.to(dtype=pipe.torch_dtype, device=pipe.device)\n        return prompt_embeds.view(1, -1)\n    \n    def encode_images_using_siglip2(self, pipe: QwenImagePipeline, images: list[Image.Image]):\n        pipe.load_models_to_device([\"siglip2_image_encoder\"])\n        embs = []\n        for image in images:\n            image = self.processor_highres(image)\n            embs.append(pipe.siglip2_image_encoder(image).to(pipe.torch_dtype))\n        embs = torch.stack(embs)\n        return embs\n    \n    def encode_images_using_dinov3(self, pipe: QwenImagePipeline, images: list[Image.Image]):\n        pipe.load_models_to_device([\"dinov3_image_encoder\"])\n        embs = []\n        for image in images:\n            image = self.processor_highres(image)\n            embs.append(pipe.dinov3_image_encoder(image).to(pipe.torch_dtype))\n        embs = torch.stack(embs)\n        return embs\n    \n    
def encode_images_using_qwenvl(self, pipe: QwenImagePipeline, images: list[Image.Image], highres=False):\n        pipe.load_models_to_device([\"text_encoder\"])\n        embs = []\n        for image in images:\n            image = self.processor_highres(image) if highres else self.processor_lowres(image)\n            embs.append(self.encode_prompt_edit(pipe, prompt=\"\", edit_image=image))\n        embs = torch.stack(embs)\n        return embs\n\n    def encode_images(self, pipe: QwenImagePipeline, images: list[Image.Image]):\n        if images is None:\n            return {}\n        if not isinstance(images, list):\n            images = [images]\n        embs_siglip2 = self.encode_images_using_siglip2(pipe, images)\n        embs_dinov3 = self.encode_images_using_dinov3(pipe, images)\n        x = torch.concat([embs_siglip2, embs_dinov3], dim=-1)\n        residual = None\n        residual_highres = None\n        if pipe.image2lora_coarse is not None:\n            residual = self.encode_images_using_qwenvl(pipe, images, highres=False)\n        if pipe.image2lora_fine is not None:\n            residual_highres = self.encode_images_using_qwenvl(pipe, images, highres=True)\n        return x, residual, residual_highres\n\n    def process(self, pipe: QwenImagePipeline, image2lora_images):\n        if image2lora_images is None:\n            return {}\n        x, residual, residual_highres = self.encode_images(pipe, image2lora_images)\n        return {\"image2lora_x\": x, \"image2lora_residual\": residual, \"image2lora_residual_highres\": residual_highres}\n\n\nclass QwenImageUnit_Image2LoRADecode(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            input_params=(\"image2lora_x\", \"image2lora_residual\", \"image2lora_residual_highres\"),\n            output_params=(\"lora\",),\n            onload_model_names=(\"image2lora_coarse\", \"image2lora_fine\", \"image2lora_style\"),\n        )\n    \n    def process(self, pipe: QwenImagePipeline, 
image2lora_x, image2lora_residual, image2lora_residual_highres):\n        if image2lora_x is None:\n            return {}\n        loras = []\n        if pipe.image2lora_style is not None:\n            pipe.load_models_to_device([\"image2lora_style\"])\n            for x in image2lora_x:\n                loras.append(pipe.image2lora_style(x=x, residual=None))\n        if pipe.image2lora_coarse is not None:\n            pipe.load_models_to_device([\"image2lora_coarse\"])\n            for x, residual in zip(image2lora_x, image2lora_residual):\n                loras.append(pipe.image2lora_coarse(x=x, residual=residual))\n        if pipe.image2lora_fine is not None:\n            pipe.load_models_to_device([\"image2lora_fine\"])\n            for x, residual in zip(image2lora_x, image2lora_residual_highres):\n                loras.append(pipe.image2lora_fine(x=x, residual=residual))\n        lora = merge_lora(loras, alpha=1 / len(image2lora_x))\n        return {\"lora\": lora}\n\n\nclass QwenImageUnit_ContextImageEmbedder(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            input_params=(\"context_image\", \"height\", \"width\", \"tiled\", \"tile_size\", \"tile_stride\", \"layer_input_image\"),\n            output_params=(\"context_latents\",),\n            onload_model_names=(\"vae\",)\n        )\n\n    def process(self, pipe: QwenImagePipeline, context_image, height, width, tiled, tile_size, tile_stride, layer_input_image=None):\n        if context_image is None:\n            return {}\n        if layer_input_image is not None:\n            context_image = context_image.convert(\"RGBA\")\n        pipe.load_models_to_device(self.onload_model_names)\n        context_image = pipe.preprocess_image(context_image.resize((width, height))).to(device=pipe.device, dtype=pipe.torch_dtype)\n        context_latents = pipe.vae.encode(context_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)\n        return {\"context_latents\": 
context_latents}\n\n\ndef model_fn_qwen_image(\n    dit: QwenImageDiT = None,\n    blockwise_controlnet: QwenImageBlockwiseMultiControlNet = None,\n    latents=None,\n    timestep=None,\n    prompt_emb=None,\n    prompt_emb_mask=None,\n    height=None,\n    width=None,\n    blockwise_controlnet_conditioning=None,\n    blockwise_controlnet_inputs=None,\n    progress_id=0,\n    num_inference_steps=1,\n    entity_prompt_emb=None,\n    entity_prompt_emb_mask=None,\n    entity_masks=None,\n    edit_latents=None,\n    layer_input_latents=None,\n    layer_num=None,\n    context_latents=None,\n    enable_fp8_attention=False,\n    use_gradient_checkpointing=False,\n    use_gradient_checkpointing_offload=False,\n    edit_rope_interpolation=False,\n    zero_cond_t=False,\n    **kwargs\n):\n    if layer_num is None:\n        layer_num = 1\n        img_shapes = [(1, latents.shape[2]//2, latents.shape[3]//2)]\n    else:\n        layer_num = layer_num + 1\n        img_shapes = [(1, latents.shape[2]//2, latents.shape[3]//2)] * layer_num\n    txt_seq_lens = prompt_emb_mask.sum(dim=1).tolist()\n    timestep = timestep / 1000\n    \n    image = rearrange(latents, \"(B N) C (H P) (W Q) -> B (N H W) (C P Q)\", H=height//16, W=width//16, P=2, Q=2, N=layer_num)\n    image_seq_len = image.shape[1]\n\n    if context_latents is not None:\n        img_shapes += [(context_latents.shape[0], context_latents.shape[2]//2, context_latents.shape[3]//2)]\n        context_image = rearrange(context_latents, \"B C (H P) (W Q) -> B (H W) (C P Q)\", H=context_latents.shape[2]//2, W=context_latents.shape[3]//2, P=2, Q=2)\n        image = torch.cat([image, context_image], dim=1)\n    if edit_latents is not None:\n        edit_latents_list = edit_latents if isinstance(edit_latents, list) else [edit_latents]\n        img_shapes += [(e.shape[0], e.shape[2]//2, e.shape[3]//2) for e in edit_latents_list]\n        edit_image = [rearrange(e, \"B C (H P) (W Q) -> B (H W) (C P Q)\", H=e.shape[2]//2, 
W=e.shape[3]//2, P=2, Q=2) for e in edit_latents_list]\n        image = torch.cat([image] + edit_image, dim=1)\n    if layer_input_latents is not None:\n        layer_num = layer_num + 1\n        img_shapes += [(layer_input_latents.shape[0], layer_input_latents.shape[2]//2, layer_input_latents.shape[3]//2)]\n        layer_input_latents = rearrange(layer_input_latents, \"B C (H P) (W Q) -> B (H W) (C P Q)\", P=2, Q=2)\n        image = torch.cat([image, layer_input_latents], dim=1)\n\n    image = dit.img_in(image)\n    if zero_cond_t:\n        timestep = torch.cat([timestep, timestep * 0], dim=0)\n        modulate_index = torch.tensor(\n            [[0] * prod(sample[0]) + [1] * sum([prod(s) for s in sample[1:]]) for sample in [img_shapes]],\n            device=timestep.device,\n            dtype=torch.int,\n        )\n    else:\n        modulate_index = None\n    conditioning = dit.time_text_embed(\n        timestep,\n        image.dtype,\n        addition_t_cond=None if not dit.time_text_embed.use_additional_t_cond else torch.tensor([0]).to(device=image.device, dtype=torch.long)\n    )\n\n    if entity_prompt_emb is not None:\n        text, image_rotary_emb, attention_mask = dit.process_entity_masks(\n            latents, prompt_emb, prompt_emb_mask, entity_prompt_emb, entity_prompt_emb_mask,\n            entity_masks, height, width, image, img_shapes,\n        )\n    else:\n        text = dit.txt_in(dit.txt_norm(prompt_emb))\n        if edit_rope_interpolation:\n            image_rotary_emb = dit.pos_embed.forward_sampling(img_shapes, txt_seq_lens, device=latents.device)\n        else:\n            image_rotary_emb = dit.pos_embed(img_shapes, txt_seq_lens, device=latents.device)\n        attention_mask = None\n        \n    if blockwise_controlnet_conditioning is not None:\n        blockwise_controlnet_conditioning = blockwise_controlnet.preprocess(\n            blockwise_controlnet_inputs, blockwise_controlnet_conditioning)\n\n    for block_id, block in 
enumerate(dit.transformer_blocks):\n        text, image = gradient_checkpoint_forward(\n            block,\n            use_gradient_checkpointing,\n            use_gradient_checkpointing_offload,\n            image=image,\n            text=text,\n            temb=conditioning,\n            image_rotary_emb=image_rotary_emb,\n            attention_mask=attention_mask,\n            enable_fp8_attention=enable_fp8_attention,\n            modulate_index=modulate_index,\n        )\n        if blockwise_controlnet_conditioning is not None:\n            image_slice = image[:, :image_seq_len].clone()\n            controlnet_output = blockwise_controlnet.blockwise_forward(\n                image=image_slice, conditionings=blockwise_controlnet_conditioning,\n                controlnet_inputs=blockwise_controlnet_inputs, block_id=block_id,\n                progress_id=progress_id, num_inference_steps=num_inference_steps,\n            )\n            image[:, :image_seq_len] = image_slice + controlnet_output\n    \n    if zero_cond_t:\n        conditioning = conditioning.chunk(2, dim=0)[0]\n    image = dit.norm_out(image, conditioning)\n    image = dit.proj_out(image)\n    image = image[:, :image_seq_len]\n    \n    latents = rearrange(image, \"B (N H W) (C P Q) -> (B N) C (H P) (W Q)\", H=height//16, W=width//16, P=2, Q=2, B=1)\n    return latents\n"
  },
  {
    "path": "diffsynth/pipelines/wan_video.py",
    "content": "import torch, types\nimport numpy as np\nfrom PIL import Image\nfrom einops import repeat\nfrom typing import Optional, Union\nfrom einops import rearrange\nimport numpy as np\nfrom PIL import Image\nfrom tqdm import tqdm\nfrom typing import Optional\nfrom typing_extensions import Literal\nfrom transformers import Wav2Vec2Processor\n\nfrom ..core.device.npu_compatible_device import get_device_type\nfrom ..diffusion import FlowMatchScheduler\nfrom ..core import ModelConfig, gradient_checkpoint_forward\nfrom ..diffusion.base_pipeline import BasePipeline, PipelineUnit\n\nfrom ..models.wan_video_dit import WanModel, sinusoidal_embedding_1d\nfrom ..models.wan_video_dit_s2v import rope_precompute\nfrom ..models.wan_video_text_encoder import WanTextEncoder, HuggingfaceTokenizer\nfrom ..models.wan_video_vae import WanVideoVAE\nfrom ..models.wan_video_image_encoder import WanImageEncoder\nfrom ..models.wan_video_vace import VaceWanModel\nfrom ..models.wan_video_motion_controller import WanMotionControllerModel\nfrom ..models.wan_video_animate_adapter import WanAnimateAdapter\nfrom ..models.wan_video_mot import MotWanModel\nfrom ..models.wav2vec import WanS2VAudioEncoder\nfrom ..models.longcat_video_dit import LongCatVideoTransformer3DModel\n\n\nclass WanVideoPipeline(BasePipeline):\n\n    def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):\n        super().__init__(\n            device=device, torch_dtype=torch_dtype,\n            height_division_factor=16, width_division_factor=16, time_division_factor=4, time_division_remainder=1\n        )\n        self.scheduler = FlowMatchScheduler(\"Wan\")\n        self.tokenizer: HuggingfaceTokenizer = None\n        self.audio_processor: Wav2Vec2Processor = None\n        self.text_encoder: WanTextEncoder = None\n        self.image_encoder: WanImageEncoder = None\n        self.dit: WanModel = None\n        self.dit2: WanModel = None\n        self.vae: WanVideoVAE = None\n        
self.motion_controller: WanMotionControllerModel = None\n        self.vace: VaceWanModel = None\n        self.vace2: VaceWanModel = None\n        self.vap: MotWanModel = None\n        self.animate_adapter: WanAnimateAdapter = None\n        self.audio_encoder: WanS2VAudioEncoder = None\n        self.in_iteration_models = (\"dit\", \"motion_controller\", \"vace\", \"animate_adapter\", \"vap\")\n        self.in_iteration_models_2 = (\"dit2\", \"motion_controller\", \"vace2\", \"animate_adapter\", \"vap\")\n        self.units = [\n            WanVideoUnit_ShapeChecker(),\n            WanVideoUnit_NoiseInitializer(),\n            WanVideoUnit_PromptEmbedder(),\n            WanVideoUnit_S2V(),\n            WanVideoUnit_InputVideoEmbedder(),\n            WanVideoUnit_ImageEmbedderVAE(),\n            WanVideoUnit_ImageEmbedderCLIP(),\n            WanVideoUnit_ImageEmbedderFused(),\n            WanVideoUnit_FunControl(),\n            WanVideoUnit_FunReference(),\n            WanVideoUnit_FunCameraControl(),\n            WanVideoUnit_SpeedControl(),\n            WanVideoUnit_VACE(),\n            WanVideoUnit_AnimateVideoSplit(),\n            WanVideoUnit_AnimatePoseLatents(),\n            WanVideoUnit_AnimateFacePixelValues(),\n            WanVideoUnit_AnimateInpaint(),\n            WanVideoUnit_VAP(),\n            WanVideoUnit_UnifiedSequenceParallel(),\n            WanVideoUnit_TeaCache(),\n            WanVideoUnit_CfgMerger(),\n            WanVideoUnit_LongCatVideo(),\n            WanVideoUnit_WanToDance_ProcessInputs(),\n            WanVideoUnit_WanToDance_RefImageEmbedder(),\n            WanVideoUnit_WanToDance_ImageKeyframesEmbedder(),\n        ]\n        self.post_units = [\n            WanVideoPostUnit_S2V(),\n        ]\n        self.model_fn = model_fn_wan_video\n\n\n    def enable_usp(self):\n        from ..utils.xfuser import get_sequence_parallel_world_size, usp_attn_forward, usp_dit_forward\n\n        for block in self.dit.blocks:\n            
block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)\n        self.dit.forward = types.MethodType(usp_dit_forward, self.dit)\n        if self.dit2 is not None:\n            for block in self.dit2.blocks:\n                block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)\n            self.dit2.forward = types.MethodType(usp_dit_forward, self.dit2)\n        self.sp_size = get_sequence_parallel_world_size()\n        self.use_unified_sequence_parallel = True\n\n\n    @staticmethod\n    def from_pretrained(\n        torch_dtype: torch.dtype = torch.bfloat16,\n        device: Union[str, torch.device] = get_device_type(),\n        model_configs: list[ModelConfig] = [],\n        tokenizer_config: ModelConfig = ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n        audio_processor_config: ModelConfig = None,\n        redirect_common_files: bool = True,\n        use_usp: bool = False,\n        vram_limit: float = None,\n    ):\n        # Redirect model path\n        if redirect_common_files:\n            redirect_dict = {\n                \"models_t5_umt5-xxl-enc-bf16.pth\": (\"DiffSynth-Studio/Wan-Series-Converted-Safetensors\", \"models_t5_umt5-xxl-enc-bf16.safetensors\"),\n                \"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\": (\"DiffSynth-Studio/Wan-Series-Converted-Safetensors\", \"models_clip_open-clip-xlm-roberta-large-vit-huge-14.safetensors\"),\n                \"Wan2.1_VAE.pth\": (\"DiffSynth-Studio/Wan-Series-Converted-Safetensors\", \"Wan2.1_VAE.safetensors\"),\n                \"Wan2.2_VAE.pth\": (\"DiffSynth-Studio/Wan-Series-Converted-Safetensors\", \"Wan2.2_VAE.safetensors\"),\n            }\n            for model_config in model_configs:\n                if model_config.origin_file_pattern is None or model_config.model_id is None:\n                    continue\n                if model_config.origin_file_pattern in redirect_dict and 
model_config.model_id != redirect_dict[model_config.origin_file_pattern][0]:\n                    print(f\"To avoid repeatedly downloading model files, ({model_config.model_id}, {model_config.origin_file_pattern}) is redirected to {redirect_dict[model_config.origin_file_pattern]}. You can use `redirect_common_files=False` to disable file redirection.\")\n                    model_config.model_id = redirect_dict[model_config.origin_file_pattern][0]\n                    model_config.origin_file_pattern = redirect_dict[model_config.origin_file_pattern][1]\n        \n        if use_usp:\n            from ..utils.xfuser import initialize_usp\n            initialize_usp(device)\n            import torch.distributed as dist\n            from ..core.device.npu_compatible_device import get_device_name\n            if dist.is_available() and dist.is_initialized():\n                device = get_device_name()\n        # Initialize pipeline\n        pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)\n        model_pool = pipe.download_and_load_models(model_configs, vram_limit)\n        \n        # Fetch models\n        pipe.text_encoder = model_pool.fetch_model(\"wan_video_text_encoder\")\n        dit = model_pool.fetch_model(\"wan_video_dit\", index=2)\n        if isinstance(dit, list):\n            pipe.dit, pipe.dit2 = dit\n        else:\n            pipe.dit = dit\n        pipe.vae = model_pool.fetch_model(\"wan_video_vae\")\n        pipe.image_encoder = model_pool.fetch_model(\"wan_video_image_encoder\")\n        pipe.motion_controller = model_pool.fetch_model(\"wan_video_motion_controller\")\n        vace = model_pool.fetch_model(\"wan_video_vace\", index=2)\n        if isinstance(vace, list):\n            pipe.vace, pipe.vace2 = vace\n        else:\n            pipe.vace = vace\n        pipe.vap = model_pool.fetch_model(\"wan_video_vap\")\n        pipe.audio_encoder = model_pool.fetch_model(\"wans2v_audio_encoder\")\n        pipe.animate_adapter = 
model_pool.fetch_model(\"wan_video_animate_adapter\")\n\n        # Size division factor\n        if pipe.vae is not None:\n            pipe.height_division_factor = pipe.vae.upsampling_factor * 2\n            pipe.width_division_factor = pipe.vae.upsampling_factor * 2\n\n        # Initialize tokenizer and processor\n        if tokenizer_config is not None:\n            tokenizer_config.download_if_necessary()\n            pipe.tokenizer = HuggingfaceTokenizer(name=tokenizer_config.path, seq_len=512, clean='whitespace')\n        if audio_processor_config is not None:\n            audio_processor_config.download_if_necessary()\n            pipe.audio_processor = Wav2Vec2Processor.from_pretrained(audio_processor_config.path)\n        \n        # Unified Sequence Parallel\n        if use_usp: pipe.enable_usp()\n        \n        # VRAM Management\n        pipe.vram_management_enabled = pipe.check_vram_management_state()\n        return pipe\n\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        # Prompt\n        prompt: str,\n        negative_prompt: Optional[str] = \"\",\n        # Image-to-video\n        input_image: Optional[Image.Image] = None,\n        # First-last-frame-to-video\n        end_image: Optional[Image.Image] = None,\n        # Video-to-video\n        input_video: Optional[list[Image.Image]] = None,\n        denoising_strength: Optional[float] = 1.0,\n        # Speech-to-video\n        input_audio: Optional[np.array] = None,\n        audio_embeds: Optional[torch.Tensor] = None,\n        audio_sample_rate: Optional[int] = 16000,\n        s2v_pose_video: Optional[list[Image.Image]] = None,\n        s2v_pose_latents: Optional[torch.Tensor] = None,\n        motion_video: Optional[list[Image.Image]] = None,\n        # ControlNet\n        control_video: Optional[list[Image.Image]] = None,\n        reference_image: Optional[Image.Image] = None,\n        # Camera control\n        camera_control_direction: Optional[Literal[\"Left\", \"Right\", 
\"Up\", \"Down\", \"LeftUp\", \"LeftDown\", \"RightUp\", \"RightDown\"]] = None,\n        camera_control_speed: Optional[float] = 1/54,\n        camera_control_origin: Optional[tuple] = (0, 0.532139961, 0.946026558, 0.5, 0.5, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0),\n        # VACE\n        vace_video: Optional[list[Image.Image]] = None,\n        vace_video_mask: Optional[Image.Image] = None,\n        vace_reference_image: Optional[Image.Image] = None,\n        vace_scale: Optional[float] = 1.0,\n        # Animate\n        animate_pose_video: Optional[list[Image.Image]] = None,\n        animate_face_video: Optional[list[Image.Image]] = None,\n        animate_inpaint_video: Optional[list[Image.Image]] = None,\n        animate_mask_video: Optional[list[Image.Image]] = None,\n        # VAP\n        vap_video: Optional[list[Image.Image]] = None,\n        vap_prompt: Optional[str] = \" \",\n        negative_vap_prompt: Optional[str] = \" \",\n        # Randomness\n        seed: Optional[int] = None,\n        rand_device: Optional[str] = \"cpu\",\n        # Shape\n        height: Optional[int] = 480,\n        width: Optional[int] = 832,\n        num_frames=81,\n        # Classifier-free guidance\n        cfg_scale: Optional[float] = 5.0,\n        cfg_merge: Optional[bool] = False,\n        # Boundary\n        switch_DiT_boundary: Optional[float] = 0.875,\n        # Scheduler\n        num_inference_steps: Optional[int] = 50,\n        sigma_shift: Optional[float] = 5.0,\n        # Speed control\n        motion_bucket_id: Optional[int] = None,\n        # LongCat-Video\n        longcat_video: Optional[list[Image.Image]] = None,\n        # VAE tiling\n        tiled: Optional[bool] = True,\n        tile_size: Optional[tuple[int, int]] = (30, 52),\n        tile_stride: Optional[tuple[int, int]] = (15, 26),\n        # Sliding window\n        sliding_window_size: Optional[int] = None,\n        sliding_window_stride: Optional[int] = None,\n        # Teacache\n        
tea_cache_l1_thresh: Optional[float] = None,\n        tea_cache_model_id: Optional[str] = \"\",\n        # WanToDance\n        wantodance_music_path: Optional[str] = None,\n        wantodance_reference_image: Optional[Image.Image] = None,\n        wantodance_fps: Optional[float] = 30,\n        wantodance_keyframes: Optional[list[Image.Image]] = None,\n        wantodance_keyframes_mask: Optional[list[int]] = None,\n        framewise_decoding: bool = False,\n        # progress_bar\n        progress_bar_cmd=tqdm,\n        output_type: Optional[Literal[\"quantized\", \"floatpoint\"]] = \"quantized\",\n    ):\n        # Scheduler\n        self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)\n        \n        # Inputs\n        inputs_posi = {\n            \"prompt\": prompt,\n            \"vap_prompt\": vap_prompt,\n            \"tea_cache_l1_thresh\": tea_cache_l1_thresh, \"tea_cache_model_id\": tea_cache_model_id, \"num_inference_steps\": num_inference_steps,\n        }\n        inputs_nega = {\n            \"negative_prompt\": negative_prompt,\n            \"negative_vap_prompt\": negative_vap_prompt,\n            \"tea_cache_l1_thresh\": tea_cache_l1_thresh, \"tea_cache_model_id\": tea_cache_model_id, \"num_inference_steps\": num_inference_steps,\n        }\n        inputs_shared = {\n            \"input_image\": input_image,\n            \"end_image\": end_image,\n            \"input_video\": input_video, \"denoising_strength\": denoising_strength,\n            \"control_video\": control_video, \"reference_image\": reference_image,\n            \"camera_control_direction\": camera_control_direction, \"camera_control_speed\": camera_control_speed, \"camera_control_origin\": camera_control_origin,\n            \"vace_video\": vace_video, \"vace_video_mask\": vace_video_mask, \"vace_reference_image\": vace_reference_image, \"vace_scale\": vace_scale,\n            \"seed\": seed, \"rand_device\": rand_device,\n     
       \"height\": height, \"width\": width, \"num_frames\": num_frames,\n            \"cfg_scale\": cfg_scale, \"cfg_merge\": cfg_merge,\n            \"sigma_shift\": sigma_shift,\n            \"motion_bucket_id\": motion_bucket_id,\n            \"longcat_video\": longcat_video,\n            \"tiled\": tiled, \"tile_size\": tile_size, \"tile_stride\": tile_stride,\n            \"sliding_window_size\": sliding_window_size, \"sliding_window_stride\": sliding_window_stride,\n            \"input_audio\": input_audio, \"audio_sample_rate\": audio_sample_rate, \"s2v_pose_video\": s2v_pose_video, \"audio_embeds\": audio_embeds, \"s2v_pose_latents\": s2v_pose_latents, \"motion_video\": motion_video,\n            \"animate_pose_video\": animate_pose_video, \"animate_face_video\": animate_face_video, \"animate_inpaint_video\": animate_inpaint_video, \"animate_mask_video\": animate_mask_video,\n            \"vap_video\": vap_video, \n            \"wantodance_music_path\": wantodance_music_path, \"wantodance_reference_image\": wantodance_reference_image, \"wantodance_fps\": wantodance_fps,\n            \"wantodance_keyframes\": wantodance_keyframes, \"wantodance_keyframes_mask\": wantodance_keyframes_mask,\n            \"framewise_decoding\": framewise_decoding,\n        }\n        for unit in self.units:\n            inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)\n\n        # Denoise\n        self.load_models_to_device(self.in_iteration_models)\n        models = {name: getattr(self, name) for name in self.in_iteration_models}\n        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):\n            # Switch DiT if necessary\n            if timestep.item() < switch_DiT_boundary * 1000 and self.dit2 is not None and not models[\"dit\"] is self.dit2:\n                self.load_models_to_device(self.in_iteration_models_2)\n                models[\"dit\"] = self.dit2\n                
models[\"vace\"] = self.vace2\n                \n            # Timestep\n            timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)\n            \n            # Inference\n            noise_pred_posi = self.model_fn(**models, **inputs_shared, **inputs_posi, timestep=timestep)\n            if cfg_scale != 1.0:\n                if cfg_merge:\n                    noise_pred_posi, noise_pred_nega = noise_pred_posi.chunk(2, dim=0)\n                else:\n                    noise_pred_nega = self.model_fn(**models, **inputs_shared, **inputs_nega, timestep=timestep)\n                noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)\n            else:\n                noise_pred = noise_pred_posi\n\n            # Scheduler\n            inputs_shared[\"latents\"] = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], inputs_shared[\"latents\"])\n            if \"first_frame_latents\" in inputs_shared:\n                inputs_shared[\"latents\"][:, :, 0:1] = inputs_shared[\"first_frame_latents\"]\n        \n        # VACE (TODO: remove it)\n        if vace_reference_image is not None or (animate_pose_video is not None and animate_face_video is not None):\n            if vace_reference_image is not None and isinstance(vace_reference_image, list):\n                f = len(vace_reference_image)\n            else:\n                f = 1\n            inputs_shared[\"latents\"] = inputs_shared[\"latents\"][:, :, f:]\n        # post-denoising, pre-decoding processing logic\n        for unit in self.post_units:\n            inputs_shared, _, _ = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)\n        # Decode\n        self.load_models_to_device(['vae'])\n        if framewise_decoding:\n            video = self.vae.decode_framewise(inputs_shared[\"latents\"], device=self.device)\n        else:\n            video = self.vae.decode(inputs_shared[\"latents\"], device=self.device, 
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)\n        if output_type == \"quantized\":\n            video = self.vae_output_to_video(video)\n        elif output_type == \"floatpoint\":\n            pass\n        self.load_models_to_device([])\n        return video\n\n\n\nclass WanVideoUnit_ShapeChecker(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            input_params=(\"height\", \"width\", \"num_frames\"),\n            output_params=(\"height\", \"width\", \"num_frames\"),\n        )\n\n    def process(self, pipe: WanVideoPipeline, height, width, num_frames):\n        height, width, num_frames = pipe.check_resize_height_width(height, width, num_frames)\n        return {\"height\": height, \"width\": width, \"num_frames\": num_frames}\n\n\n\nclass WanVideoUnit_NoiseInitializer(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            input_params=(\"height\", \"width\", \"num_frames\", \"seed\", \"rand_device\", \"vace_reference_image\"),\n            output_params=(\"noise\",)\n        )\n\n    def process(self, pipe: WanVideoPipeline, height, width, num_frames, seed, rand_device, vace_reference_image):\n        length = (num_frames - 1) // 4 + 1\n        if vace_reference_image is not None:\n            f = len(vace_reference_image) if isinstance(vace_reference_image, list) else 1\n            length += f\n        shape = (1, pipe.vae.model.z_dim, length, height // pipe.vae.upsampling_factor, width // pipe.vae.upsampling_factor)\n        noise = pipe.generate_noise(shape, seed=seed, rand_device=rand_device)\n        if vace_reference_image is not None:\n            noise = torch.concat((noise[:, :, -f:], noise[:, :, :-f]), dim=2)\n        return {\"noise\": noise}\n    \n\n\nclass WanVideoUnit_InputVideoEmbedder(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            input_params=(\"input_video\", \"noise\", \"tiled\", \"tile_size\", \"tile_stride\", \"vace_reference_image\", 
\"framewise_decoding\"),\n            output_params=(\"latents\", \"input_latents\"),\n            onload_model_names=(\"vae\",)\n        )\n\n    def process(self, pipe: WanVideoPipeline, input_video, noise, tiled, tile_size, tile_stride, vace_reference_image, framewise_decoding):\n        if input_video is None:\n            return {\"latents\": noise}\n        pipe.load_models_to_device(self.onload_model_names)\n        input_video = pipe.preprocess_video(input_video)\n        if framewise_decoding:\n            input_latents = pipe.vae.encode_framewise(input_video, device=pipe.device)\n        else:\n            input_latents = pipe.vae.encode(input_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)\n        if vace_reference_image is not None:\n            if not isinstance(vace_reference_image, list):\n                vace_reference_image = [vace_reference_image]\n            vace_reference_image = pipe.preprocess_video(vace_reference_image)\n            vace_reference_latents = pipe.vae.encode(vace_reference_image, device=pipe.device).to(dtype=pipe.torch_dtype, device=pipe.device)\n            input_latents = torch.concat([vace_reference_latents, input_latents], dim=2)\n        if pipe.scheduler.training:\n            return {\"latents\": noise, \"input_latents\": input_latents}\n        else:\n            latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])\n            return {\"latents\": latents}\n\n\n\nclass WanVideoUnit_PromptEmbedder(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            seperate_cfg=True,\n            input_params_posi={\"prompt\": \"prompt\", \"positive\": \"positive\"},\n            input_params_nega={\"prompt\": \"negative_prompt\", \"positive\": \"positive\"},\n            output_params=(\"context\",),\n            onload_model_names=(\"text_encoder\",)\n        )\n    \n    def 
encode_prompt(self, pipe: WanVideoPipeline, prompt):\n        ids, mask = pipe.tokenizer(prompt, return_mask=True, add_special_tokens=True)\n        ids = ids.to(pipe.device)\n        mask = mask.to(pipe.device)\n        seq_lens = mask.gt(0).sum(dim=1).long()\n        prompt_emb = pipe.text_encoder(ids, mask)\n        for i, v in enumerate(seq_lens):\n            prompt_emb[:, v:] = 0\n        return prompt_emb\n\n    def process(self, pipe: WanVideoPipeline, prompt, positive) -> dict:\n        pipe.load_models_to_device(self.onload_model_names)\n        prompt_emb = self.encode_prompt(pipe, prompt)\n        return {\"context\": prompt_emb}\n\n\n\nclass WanVideoUnit_ImageEmbedderCLIP(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            input_params=(\"input_image\", \"end_image\", \"height\", \"width\"),\n            output_params=(\"clip_feature\",),\n            onload_model_names=(\"image_encoder\",)\n        )\n\n    def process(self, pipe: WanVideoPipeline, input_image, end_image, height, width):\n        if input_image is None or pipe.image_encoder is None or not pipe.dit.require_clip_embedding:\n            return {}\n        pipe.load_models_to_device(self.onload_model_names)\n        image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device)\n        clip_context = pipe.image_encoder.encode_image([image])\n        if end_image is not None:\n            end_image = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device)\n            if pipe.dit.has_image_pos_emb:\n                clip_context = torch.concat([clip_context, pipe.image_encoder.encode_image([end_image])], dim=1)\n        clip_context = clip_context.to(dtype=pipe.torch_dtype, device=pipe.device)\n        return {\"clip_feature\": clip_context}\n    \n\n\nclass WanVideoUnit_ImageEmbedderVAE(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            input_params=(\"input_image\", \"end_image\", 
\"num_frames\", \"height\", \"width\", \"tiled\", \"tile_size\", \"tile_stride\"),\n            output_params=(\"y\",),\n            onload_model_names=(\"vae\",)\n        )\n\n    def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride):\n        if input_image is None or not pipe.dit.require_vae_embedding:\n            return {}\n        pipe.load_models_to_device(self.onload_model_names)\n        image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device)\n        msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device)\n        msk[:, 1:] = 0\n        if end_image is not None:\n            end_image = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device)\n            vae_input = torch.concat([image.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image.device), end_image.transpose(0,1)],dim=1)\n            msk[:, -1:] = 1\n        else:\n            vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)\n\n        msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)\n        msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)\n        msk = msk.transpose(1, 2)[0]\n        \n        y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]\n        y = y.to(dtype=pipe.torch_dtype, device=pipe.device)\n        y = torch.concat([msk, y])\n        y = y.unsqueeze(0)\n        y = y.to(dtype=pipe.torch_dtype, device=pipe.device)\n        return {\"y\": y}\n\n\n\nclass WanVideoUnit_ImageEmbedderFused(PipelineUnit):\n    \"\"\"\n    Encode input image to latents using VAE. 
This unit is for Wan-AI/Wan2.2-TI2V-5B.\n    \"\"\"\n    def __init__(self):\n        super().__init__(\n            input_params=(\"input_image\", \"latents\", \"height\", \"width\", \"tiled\", \"tile_size\", \"tile_stride\"),\n            output_params=(\"latents\", \"fuse_vae_embedding_in_latents\", \"first_frame_latents\"),\n            onload_model_names=(\"vae\",)\n        )\n\n    def process(self, pipe: WanVideoPipeline, input_image, latents, height, width, tiled, tile_size, tile_stride):\n        if input_image is None or not pipe.dit.fuse_vae_embedding_in_latents:\n            return {}\n        pipe.load_models_to_device(self.onload_model_names)\n        image = pipe.preprocess_image(input_image.resize((width, height))).transpose(0, 1)\n        z = pipe.vae.encode([image], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)\n        latents[:, :, 0: 1] = z\n        return {\"latents\": latents, \"fuse_vae_embedding_in_latents\": True, \"first_frame_latents\": z}\n\n\n\nclass WanVideoUnit_FunControl(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            input_params=(\"control_video\", \"num_frames\", \"height\", \"width\", \"tiled\", \"tile_size\", \"tile_stride\", \"clip_feature\", \"y\", \"latents\"),\n            output_params=(\"clip_feature\", \"y\"),\n            onload_model_names=(\"vae\",)\n        )\n\n    def process(self, pipe: WanVideoPipeline, control_video, num_frames, height, width, tiled, tile_size, tile_stride, clip_feature, y, latents):\n        if control_video is None:\n            return {}\n        pipe.load_models_to_device(self.onload_model_names)\n        control_video = pipe.preprocess_video(control_video)\n        control_latents = pipe.vae.encode(control_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)\n        control_latents = control_latents.to(dtype=pipe.torch_dtype, device=pipe.device)\n 
       y_dim = pipe.dit.in_dim-control_latents.shape[1]-latents.shape[1]\n        if clip_feature is None or y is None:\n            clip_feature = torch.zeros((1, 257, 1280), dtype=pipe.torch_dtype, device=pipe.device)\n            y = torch.zeros((1, y_dim, (num_frames - 1) // 4 + 1, height//8, width//8), dtype=pipe.torch_dtype, device=pipe.device)\n        else:\n            y = y[:, -y_dim:]\n        y = torch.concat([control_latents, y], dim=1)\n        return {\"clip_feature\": clip_feature, \"y\": y}\n    \n\n\nclass WanVideoUnit_FunReference(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            input_params=(\"reference_image\", \"height\", \"width\", \"reference_image\"),\n            output_params=(\"reference_latents\", \"clip_feature\"),\n            onload_model_names=(\"vae\", \"image_encoder\")\n        )\n\n    def process(self, pipe: WanVideoPipeline, reference_image, height, width):\n        if reference_image is None:\n            return {}\n        pipe.load_models_to_device(self.onload_model_names)\n        reference_image = reference_image.resize((width, height))\n        reference_latents = pipe.preprocess_video([reference_image])\n        reference_latents = pipe.vae.encode(reference_latents, device=pipe.device)\n        if pipe.image_encoder is None:\n            return {\"reference_latents\": reference_latents}\n        clip_feature = pipe.preprocess_image(reference_image)\n        clip_feature = pipe.image_encoder.encode_image([clip_feature])\n        return {\"reference_latents\": reference_latents, \"clip_feature\": clip_feature}\n\n\n\nclass WanVideoUnit_FunCameraControl(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            input_params=(\"height\", \"width\", \"num_frames\", \"camera_control_direction\", \"camera_control_speed\", \"camera_control_origin\", \"latents\", \"input_image\", \"tiled\", \"tile_size\", \"tile_stride\"),\n            
output_params=(\"control_camera_latents_input\", \"y\"),\n            onload_model_names=(\"vae\",)\n        )\n\n    def process(self, pipe: WanVideoPipeline, height, width, num_frames, camera_control_direction, camera_control_speed, camera_control_origin, latents, input_image, tiled, tile_size, tile_stride):\n        if camera_control_direction is None:\n            return {}\n        pipe.load_models_to_device(self.onload_model_names)\n        camera_control_plucker_embedding = pipe.dit.control_adapter.process_camera_coordinates(\n            camera_control_direction, num_frames, height, width, camera_control_speed, camera_control_origin)\n        \n        control_camera_video = camera_control_plucker_embedding[:num_frames].permute([3, 0, 1, 2]).unsqueeze(0)\n        control_camera_latents = torch.concat(\n            [\n                torch.repeat_interleave(control_camera_video[:, :, 0:1], repeats=4, dim=2),\n                control_camera_video[:, :, 1:]\n            ], dim=2\n        ).transpose(1, 2)\n        b, f, c, h, w = control_camera_latents.shape\n        control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, 4, c, h, w).transpose(2, 3)\n        control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, c * 4, h, w).transpose(1, 2)\n        control_camera_latents_input = control_camera_latents.to(device=pipe.device, dtype=pipe.torch_dtype)\n        \n        input_image = input_image.resize((width, height))\n        input_latents = pipe.preprocess_video([input_image])\n        input_latents = pipe.vae.encode(input_latents, device=pipe.device)\n        y = torch.zeros_like(latents).to(pipe.device)\n        y[:, :, :1] = input_latents\n        y = y.to(dtype=pipe.torch_dtype, device=pipe.device)\n\n        if y.shape[1] != pipe.dit.in_dim - latents.shape[1]:\n            image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device)\n            vae_input = torch.concat([image.transpose(0, 
1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)\n            y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]\n            y = y.to(dtype=pipe.torch_dtype, device=pipe.device)\n            msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device)\n            msk[:, 1:] = 0\n            msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)\n            msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)\n            msk = msk.transpose(1, 2)[0]\n            y = torch.cat([msk,y])\n            y = y.unsqueeze(0)\n            y = y.to(dtype=pipe.torch_dtype, device=pipe.device)\n        return {\"control_camera_latents_input\": control_camera_latents_input, \"y\": y}\n\n\n\nclass WanVideoUnit_SpeedControl(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            input_params=(\"motion_bucket_id\",),\n            output_params=(\"motion_bucket_id\",)\n        )\n\n    def process(self, pipe: WanVideoPipeline, motion_bucket_id):\n        if motion_bucket_id is None:\n            return {}\n        motion_bucket_id = torch.Tensor((motion_bucket_id,)).to(dtype=pipe.torch_dtype, device=pipe.device)\n        return {\"motion_bucket_id\": motion_bucket_id}\n\n\n\nclass WanVideoUnit_VACE(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            input_params=(\"vace_video\", \"vace_video_mask\", \"vace_reference_image\", \"vace_scale\", \"height\", \"width\", \"num_frames\", \"tiled\", \"tile_size\", \"tile_stride\"),\n            output_params=(\"vace_context\", \"vace_scale\"),\n            onload_model_names=(\"vae\",)\n        )\n\n    def process(\n        self,\n        pipe: WanVideoPipeline,\n        vace_video, vace_video_mask, vace_reference_image, vace_scale,\n        height, width, num_frames,\n        tiled, tile_size, 
tile_stride\n    ):\n        if vace_video is not None or vace_video_mask is not None or vace_reference_image is not None:\n            pipe.load_models_to_device([\"vae\"])\n            if vace_video is None:\n                vace_video = torch.zeros((1, 3, num_frames, height, width), dtype=pipe.torch_dtype, device=pipe.device)\n            else:\n                vace_video = pipe.preprocess_video(vace_video)\n            \n            if vace_video_mask is None:\n                vace_video_mask = torch.ones_like(vace_video)\n            else:\n                vace_video_mask = pipe.preprocess_video(vace_video_mask, min_value=0, max_value=1)\n            \n            inactive = vace_video * (1 - vace_video_mask) + 0 * vace_video_mask\n            reactive = vace_video * vace_video_mask + 0 * (1 - vace_video_mask)\n            inactive = pipe.vae.encode(inactive, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)\n            reactive = pipe.vae.encode(reactive, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)\n            vace_video_latents = torch.concat((inactive, reactive), dim=1)\n            \n            vace_mask_latents = rearrange(vace_video_mask[0,0], \"T (H P) (W Q) -> 1 (P Q) T H W\", P=8, Q=8)\n            vace_mask_latents = torch.nn.functional.interpolate(vace_mask_latents, size=((vace_mask_latents.shape[2] + 3) // 4, vace_mask_latents.shape[3], vace_mask_latents.shape[4]), mode='nearest-exact')\n            \n            if vace_reference_image is None:\n                pass\n            else:\n                if not isinstance(vace_reference_image,list):\n                    vace_reference_image = [vace_reference_image]\n\n                vace_reference_image = pipe.preprocess_video(vace_reference_image)\n\n                bs, c, f, h, w = vace_reference_image.shape\n                
new_vace_ref_images = []\n                for j in range(f):\n                    new_vace_ref_images.append(vace_reference_image[0, :, j:j+1])\n                vace_reference_image = new_vace_ref_images\n                \n                vace_reference_latents = pipe.vae.encode(vace_reference_image, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)\n                vace_reference_latents = torch.concat((vace_reference_latents, torch.zeros_like(vace_reference_latents)), dim=1)\n                vace_reference_latents = [u.unsqueeze(0) for u in vace_reference_latents]\n\n                vace_video_latents = torch.concat((*vace_reference_latents, vace_video_latents), dim=2)\n                vace_mask_latents = torch.concat((torch.zeros_like(vace_mask_latents[:, :, :f]), vace_mask_latents), dim=2)\n            \n            vace_context = torch.concat((vace_video_latents, vace_mask_latents), dim=1)\n            return {\"vace_context\": vace_context, \"vace_scale\": vace_scale}\n        else:\n            return {\"vace_context\": None, \"vace_scale\": vace_scale}\n\n\nclass WanVideoUnit_VAP(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            take_over=True,\n            onload_model_names=(\"text_encoder\", \"vae\", \"image_encoder\"),\n            input_params=(\"vap_video\", \"vap_prompt\", \"negative_vap_prompt\", \"end_image\", \"num_frames\", \"height\", \"width\", \"tiled\", \"tile_size\", \"tile_stride\"),\n            output_params=(\"vap_clip_feature\", \"vap_hidden_state\", \"context_vap\")\n        )\n        \n    def encode_prompt(self, pipe: WanVideoPipeline, prompt):\n        ids, mask = pipe.tokenizer(prompt, return_mask=True, add_special_tokens=True)\n        ids = ids.to(pipe.device)\n        mask = mask.to(pipe.device)\n        seq_lens = mask.gt(0).sum(dim=1).long()\n        prompt_emb = pipe.text_encoder(ids, mask)\n        for i, v in 
enumerate(seq_lens):\n            prompt_emb[:, v:] = 0\n        return prompt_emb\n\n    def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega):\n        if inputs_shared.get(\"vap_video\") is None:\n            return inputs_shared, inputs_posi, inputs_nega\n        else:\n            # 1. encode vap prompt\n            pipe.load_models_to_device([\"text_encoder\"])\n            vap_prompt, negative_vap_prompt = inputs_posi.get(\"vap_prompt\", \"\"), inputs_nega.get(\"negative_vap_prompt\", \"\")\n            vap_prompt_emb = self.encode_prompt(pipe, vap_prompt)\n            negative_vap_prompt_emb = self.encode_prompt(pipe, negative_vap_prompt)\n            inputs_posi.update({\"context_vap\":vap_prompt_emb})\n            inputs_nega.update({\"context_vap\":negative_vap_prompt_emb})\n            # 2. prepare vap image clip embedding\n            pipe.load_models_to_device([\"vae\", \"image_encoder\"])\n            vap_video, end_image = inputs_shared.get(\"vap_video\"), inputs_shared.get(\"end_image\")\n\n            num_frames, height, width = inputs_shared.get(\"num_frames\"),inputs_shared.get(\"height\"), inputs_shared.get(\"width\")\n            \n            image_vap = pipe.preprocess_image(vap_video[0].resize((width, height))).to(pipe.device)\n\n            vap_clip_context = pipe.image_encoder.encode_image([image_vap])\n            if end_image is not None:\n                vap_end_image = pipe.preprocess_image(vap_video[-1].resize((width, height))).to(pipe.device)\n                if pipe.dit.has_image_pos_emb:\n                    vap_clip_context = torch.concat([vap_clip_context, pipe.image_encoder.encode_image([vap_end_image])], dim=1)\n            vap_clip_context = vap_clip_context.to(dtype=pipe.torch_dtype, device=pipe.device)\n            inputs_shared.update({\"vap_clip_feature\":vap_clip_context})\n\n            # 3. 
prepare vap latents            \n            msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device)\n            msk[:, 1:] = 0\n            if end_image is not None:\n                msk[:, -1:] = 1\n                last_image_vap = pipe.preprocess_image(vap_video[-1].resize((width, height))).to(pipe.device)\n                vae_input = torch.concat([image_vap.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image_vap.device), last_image_vap.transpose(0,1)],dim=1)\n            else:\n                vae_input = torch.concat([image_vap.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image_vap.device)], dim=1)\n            \n            msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)\n            msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)\n            msk = msk.transpose(1, 2)[0]\n\n            tiled,tile_size,tile_stride = inputs_shared.get(\"tiled\"), inputs_shared.get(\"tile_size\"), inputs_shared.get(\"tile_stride\")\n\n            y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]\n            y = y.to(dtype=pipe.torch_dtype, device=pipe.device)\n            y = torch.concat([msk, y])\n            y = y.unsqueeze(0)\n            y = y.to(dtype=pipe.torch_dtype, device=pipe.device)\n\n            vap_video = pipe.preprocess_video(vap_video)\n            vap_latent = pipe.vae.encode(vap_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)\n\n            vap_latent = torch.concat([vap_latent,y], dim=1).to(dtype=pipe.torch_dtype, device=pipe.device)\n            inputs_shared.update({\"vap_hidden_state\":vap_latent})\n\n            return inputs_shared, inputs_posi, inputs_nega\n\n\n\nclass WanVideoUnit_UnifiedSequenceParallel(PipelineUnit):\n    def 
__init__(self):\n        super().__init__(input_params=(), output_params=(\"use_unified_sequence_parallel\",))\n\n    def process(self, pipe: WanVideoPipeline):\n        if hasattr(pipe, \"use_unified_sequence_parallel\"):\n            if pipe.use_unified_sequence_parallel:\n                return {\"use_unified_sequence_parallel\": True}\n        return {}\n\n\n\nclass WanVideoUnit_TeaCache(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            seperate_cfg=True,\n            input_params_posi={\"num_inference_steps\": \"num_inference_steps\", \"tea_cache_l1_thresh\": \"tea_cache_l1_thresh\", \"tea_cache_model_id\": \"tea_cache_model_id\"},\n            input_params_nega={\"num_inference_steps\": \"num_inference_steps\", \"tea_cache_l1_thresh\": \"tea_cache_l1_thresh\", \"tea_cache_model_id\": \"tea_cache_model_id\"},\n            output_params=(\"tea_cache\",)\n        )\n\n    def process(self, pipe: WanVideoPipeline, num_inference_steps, tea_cache_l1_thresh, tea_cache_model_id):\n        if tea_cache_l1_thresh is None:\n            return {}\n        return {\"tea_cache\": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id)}\n\n\n\nclass WanVideoUnit_CfgMerger(PipelineUnit):\n    def __init__(self):\n        super().__init__(take_over=True)\n        self.concat_tensor_names = [\"context\", \"clip_feature\", \"y\", \"reference_latents\"]\n\n    def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega):\n        if not inputs_shared[\"cfg_merge\"]:\n            return inputs_shared, inputs_posi, inputs_nega\n        for name in self.concat_tensor_names:\n            tensor_posi = inputs_posi.get(name)\n            tensor_nega = inputs_nega.get(name)\n            tensor_shared = inputs_shared.get(name)\n            if tensor_posi is not None and tensor_nega is not None:\n                inputs_shared[name] = torch.concat((tensor_posi, tensor_nega), dim=0)\n            elif 
tensor_shared is not None:\n                inputs_shared[name] = torch.concat((tensor_shared, tensor_shared), dim=0)\n        inputs_posi.clear()\n        inputs_nega.clear()\n        return inputs_shared, inputs_posi, inputs_nega\n\n\nclass WanVideoUnit_S2V(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            take_over=True,\n            onload_model_names=(\"audio_encoder\", \"vae\",),\n            input_params=(\"input_audio\", \"audio_embeds\", \"num_frames\", \"height\", \"width\", \"tiled\", \"tile_size\", \"tile_stride\", \"audio_sample_rate\", \"s2v_pose_video\", \"s2v_pose_latents\", \"motion_video\"),\n            output_params=(\"audio_embeds\", \"motion_latents\", \"drop_motion_frames\", \"s2v_pose_latents\"),\n        )\n\n    def process_audio(self, pipe: WanVideoPipeline, input_audio, audio_sample_rate, num_frames, fps=16, audio_embeds=None, return_all=False):\n        if audio_embeds is not None:\n            return {\"audio_embeds\": audio_embeds}\n        pipe.load_models_to_device([\"audio_encoder\"])\n        audio_embeds = pipe.audio_encoder.get_audio_feats_per_inference(input_audio, audio_sample_rate, pipe.audio_processor, fps=fps, batch_frames=num_frames-1, dtype=pipe.torch_dtype, device=pipe.device)\n        if return_all:\n            return audio_embeds\n        else:\n            return {\"audio_embeds\": audio_embeds[0]}\n\n    def process_motion_latents(self, pipe: WanVideoPipeline, height, width, tiled, tile_size, tile_stride, motion_video=None):\n        pipe.load_models_to_device([\"vae\"])\n        motion_frames = 73\n        kwargs = {}\n        if motion_video is not None:\n            assert motion_video.shape[2] == motion_frames, f\"motion video must have {motion_frames} frames, but got {motion_video.shape[2]}\"\n            motion_latents = motion_video\n            kwargs[\"drop_motion_frames\"] = False\n        else:\n            motion_latents = torch.zeros([1, 3, motion_frames, height, width], 
dtype=pipe.torch_dtype, device=pipe.device)\n            kwargs[\"drop_motion_frames\"] = True\n        motion_latents = pipe.vae.encode(motion_latents, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)\n        kwargs.update({\"motion_latents\": motion_latents})\n        return kwargs\n\n    def process_pose_cond(self, pipe: WanVideoPipeline, s2v_pose_video, num_frames, height, width, tiled, tile_size, tile_stride, s2v_pose_latents=None, num_repeats=1, return_all=False):\n        if s2v_pose_latents is not None:\n            return {\"s2v_pose_latents\": s2v_pose_latents}\n        if s2v_pose_video is None:\n            return {\"s2v_pose_latents\": None}\n        pipe.load_models_to_device([\"vae\"])\n        infer_frames = num_frames - 1\n        input_video = pipe.preprocess_video(s2v_pose_video)[:, :, :infer_frames * num_repeats]\n        # pad if not enough frames\n        padding_frames = infer_frames * num_repeats - input_video.shape[2]\n        input_video = torch.cat([input_video, -torch.ones(1, 3, padding_frames, height, width, device=input_video.device, dtype=input_video.dtype)], dim=2)\n        input_videos = input_video.chunk(num_repeats, dim=2)\n        pose_conds = []\n        for r in range(num_repeats):\n            cond = input_videos[r]\n            cond = torch.cat([cond[:, :, 0:1].repeat(1, 1, 1, 1, 1), cond], dim=2)\n            cond_latents = pipe.vae.encode(cond, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)\n            pose_conds.append(cond_latents[:,:,1:])\n        if return_all:\n            return pose_conds\n        else:\n            return {\"s2v_pose_latents\": pose_conds[0]}\n\n    def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega):\n        if (inputs_shared.get(\"input_audio\") is None and inputs_shared.get(\"audio_embeds\") is None) or 
pipe.audio_encoder is None or pipe.audio_processor is None:\n            return inputs_shared, inputs_posi, inputs_nega\n        num_frames, height, width, tiled, tile_size, tile_stride = inputs_shared.get(\"num_frames\"), inputs_shared.get(\"height\"), inputs_shared.get(\"width\"), inputs_shared.get(\"tiled\"), inputs_shared.get(\"tile_size\"), inputs_shared.get(\"tile_stride\")\n        input_audio, audio_embeds, audio_sample_rate = inputs_shared.pop(\"input_audio\", None), inputs_shared.pop(\"audio_embeds\", None), inputs_shared.get(\"audio_sample_rate\", 16000)\n        s2v_pose_video, s2v_pose_latents, motion_video = inputs_shared.pop(\"s2v_pose_video\", None), inputs_shared.pop(\"s2v_pose_latents\", None), inputs_shared.pop(\"motion_video\", None)\n\n        audio_input_positive = self.process_audio(pipe, input_audio, audio_sample_rate, num_frames, audio_embeds=audio_embeds)\n        inputs_posi.update(audio_input_positive)\n        inputs_nega.update({\"audio_embeds\": 0.0 * audio_input_positive[\"audio_embeds\"]})\n\n        inputs_shared.update(self.process_motion_latents(pipe, height, width, tiled, tile_size, tile_stride, motion_video))\n        inputs_shared.update(self.process_pose_cond(pipe, s2v_pose_video, num_frames, height, width, tiled, tile_size, tile_stride, s2v_pose_latents=s2v_pose_latents))\n        return inputs_shared, inputs_posi, inputs_nega\n\n    @staticmethod\n    def pre_calculate_audio_pose(pipe: WanVideoPipeline, input_audio=None, audio_sample_rate=16000, s2v_pose_video=None, num_frames=81, height=448, width=832, fps=16, tiled=True, tile_size=(30, 52), tile_stride=(15, 26)):\n        assert pipe.audio_encoder is not None and pipe.audio_processor is not None, \"Please load audio encoder and audio processor first.\"\n        shapes = WanVideoUnit_ShapeChecker().process(pipe, height, width, num_frames)\n        height, width, num_frames = shapes[\"height\"], shapes[\"width\"], shapes[\"num_frames\"]\n        unit = WanVideoUnit_S2V()\n  
      audio_embeds = unit.process_audio(pipe, input_audio, audio_sample_rate, num_frames, fps, return_all=True)\n        pose_latents = unit.process_pose_cond(pipe, s2v_pose_video, num_frames, height, width, num_repeats=len(audio_embeds), return_all=True, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)\n        pose_latents = None if s2v_pose_video is None else pose_latents\n        return audio_embeds, pose_latents, len(audio_embeds)\n\n\nclass WanVideoPostUnit_S2V(PipelineUnit):\n    def __init__(self):\n        super().__init__(input_params=(\"latents\", \"motion_latents\", \"drop_motion_frames\"))\n\n    def process(self, pipe: WanVideoPipeline, latents, motion_latents, drop_motion_frames):\n        if pipe.audio_encoder is None or motion_latents is None or drop_motion_frames:\n            return {}\n        latents = torch.cat([motion_latents, latents[:,:,1:]], dim=2)\n        return {\"latents\": latents}\n\n\nclass WanVideoUnit_AnimateVideoSplit(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            input_params=(\"input_video\", \"animate_pose_video\", \"animate_face_video\", \"animate_inpaint_video\", \"animate_mask_video\"),\n            output_params=(\"animate_pose_video\", \"animate_face_video\", \"animate_inpaint_video\", \"animate_mask_video\")\n        )\n\n    def process(self, pipe: WanVideoPipeline, input_video, animate_pose_video, animate_face_video, animate_inpaint_video, animate_mask_video):\n        if input_video is None:\n            return {}\n        if animate_pose_video is not None:\n            animate_pose_video = animate_pose_video[:len(input_video) - 4]\n        if animate_face_video is not None:\n            animate_face_video = animate_face_video[:len(input_video) - 4]\n        if animate_inpaint_video is not None:\n            animate_inpaint_video = animate_inpaint_video[:len(input_video) - 4]\n        if animate_mask_video is not None:\n            animate_mask_video = 
animate_mask_video[:len(input_video) - 4]\n        return {\"animate_pose_video\": animate_pose_video, \"animate_face_video\": animate_face_video, \"animate_inpaint_video\": animate_inpaint_video, \"animate_mask_video\": animate_mask_video}\n\n\nclass WanVideoUnit_AnimatePoseLatents(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            input_params=(\"animate_pose_video\", \"tiled\", \"tile_size\", \"tile_stride\"),\n            output_params=(\"pose_latents\",),\n            onload_model_names=(\"vae\",)\n        )\n\n    def process(self, pipe: WanVideoPipeline, animate_pose_video, tiled, tile_size, tile_stride):\n        if animate_pose_video is None:\n            return {}\n        pipe.load_models_to_device(self.onload_model_names)\n        animate_pose_video = pipe.preprocess_video(animate_pose_video)\n        pose_latents = pipe.vae.encode(animate_pose_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)\n        return {\"pose_latents\": pose_latents}\n\n\nclass WanVideoUnit_AnimateFacePixelValues(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            take_over=True,\n            input_params=(\"animate_face_video\",),\n            output_params=(\"face_pixel_values\"),\n        )\n\n    def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega):\n        if inputs_shared.get(\"animate_face_video\", None) is None:\n            return inputs_shared, inputs_posi, inputs_nega\n        inputs_posi[\"face_pixel_values\"] = pipe.preprocess_video(inputs_shared[\"animate_face_video\"])\n        inputs_nega[\"face_pixel_values\"] = torch.zeros_like(inputs_posi[\"face_pixel_values\"]) - 1\n        return inputs_shared, inputs_posi, inputs_nega\n\n\nclass WanVideoUnit_AnimateInpaint(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            input_params=(\"animate_inpaint_video\", 
\"animate_mask_video\", \"input_image\", \"tiled\", \"tile_size\", \"tile_stride\"),\n            output_params=(\"y\",),\n            onload_model_names=(\"vae\",)\n        )\n        \n    def get_i2v_mask(self, lat_t, lat_h, lat_w, mask_len=1, mask_pixel_values=None, device=get_device_type()):\n        if mask_pixel_values is None:\n            msk = torch.zeros(1, (lat_t-1) * 4 + 1, lat_h, lat_w, device=device)\n        else:\n            msk = mask_pixel_values.clone()\n        msk[:, :mask_len] = 1\n        msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)\n        msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)\n        msk = msk.transpose(1, 2)[0]\n        return msk\n\n    def process(self, pipe: WanVideoPipeline, animate_inpaint_video, animate_mask_video, input_image, tiled, tile_size, tile_stride):\n        if animate_inpaint_video is None or animate_mask_video is None:\n            return {}\n        pipe.load_models_to_device(self.onload_model_names)\n\n        bg_pixel_values = pipe.preprocess_video(animate_inpaint_video)\n        y_reft = pipe.vae.encode(bg_pixel_values, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0].to(dtype=pipe.torch_dtype, device=pipe.device)\n        _, lat_t, lat_h, lat_w = y_reft.shape\n        \n        ref_pixel_values = pipe.preprocess_video([input_image])\n        ref_latents = pipe.vae.encode(ref_pixel_values, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)\n        mask_ref = self.get_i2v_mask(1, lat_h, lat_w, 1, device=pipe.device)\n        y_ref = torch.concat([mask_ref, ref_latents[0]]).to(dtype=torch.bfloat16, device=pipe.device)\n        \n        mask_pixel_values = 1 - pipe.preprocess_video(animate_mask_video, max_value=1, min_value=0)\n        mask_pixel_values = rearrange(mask_pixel_values, \"b c t h w -> (b t) c h w\")\n        mask_pixel_values = 
torch.nn.functional.interpolate(mask_pixel_values, size=(lat_h, lat_w), mode='nearest')\n        mask_pixel_values = rearrange(mask_pixel_values, \"(b t) c h w -> b t c h w\", b=1)[:,:,0]\n        msk_reft = self.get_i2v_mask(lat_t, lat_h, lat_w, 0, mask_pixel_values=mask_pixel_values, device=pipe.device)\n        \n        y_reft = torch.concat([msk_reft, y_reft]).to(dtype=torch.bfloat16, device=pipe.device)\n        y = torch.concat([y_ref, y_reft], dim=1).unsqueeze(0)\n        return {\"y\": y}\n\n\nclass WanVideoUnit_LongCatVideo(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            input_params=(\"longcat_video\",),\n            output_params=(\"longcat_latents\",),\n            onload_model_names=(\"vae\",)\n        )\n\n    def process(self, pipe: WanVideoPipeline, longcat_video):\n        if longcat_video is None:\n            return {}\n        pipe.load_models_to_device(self.onload_model_names)\n        longcat_video = pipe.preprocess_video(longcat_video)\n        longcat_latents = pipe.vae.encode(longcat_video, device=pipe.device).to(dtype=pipe.torch_dtype, device=pipe.device)\n        return {\"longcat_latents\": longcat_latents}\n\n\nclass WanVideoUnit_WanToDance_ProcessInputs(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            take_over=True,\n        )\n\n    def get_music_base_feature(self, music_path, fps=30):\n        import librosa\n        hop_length = 512\n        sr = fps * hop_length\n        data, sr = librosa.load(music_path, sr=sr)\n        sr = 22050 \n        envelope = librosa.onset.onset_strength(y=data, sr=sr)\n        mfcc = librosa.feature.mfcc(y=data, sr=sr, n_mfcc=20).T  \n        chroma = librosa.feature.chroma_cens(\n            y=data, sr=sr, hop_length=hop_length, n_chroma=12\n        ).T \n        peak_idxs = librosa.onset.onset_detect(\n            onset_envelope=envelope.flatten(), sr=sr, hop_length=hop_length\n        )\n        peak_onehot = np.zeros_like(envelope, 
dtype=np.float32)\n        peak_onehot[peak_idxs] = 1.0\n        start_bpm = librosa.beat.tempo(y=librosa.load(music_path)[0])[0]\n        _, beat_idxs = librosa.beat.beat_track(\n            onset_envelope=envelope,\n            sr=sr,\n            hop_length=hop_length,\n            start_bpm=start_bpm,\n            tightness=100,\n        )\n        beat_onehot = np.zeros_like(envelope, dtype=np.float32)\n        beat_onehot[beat_idxs] = 1.0  \n        audio_feature = np.concatenate(\n            [envelope[:, None], mfcc, chroma, peak_onehot[:, None], beat_onehot[:, None]],\n            axis=-1,\n        )\n        return torch.from_numpy(audio_feature)\n\n    def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega):\n        if pipe.dit.wantodance_enable_global:\n            inputs_nega[\"skip_9th_layer\"] = True\n        if inputs_shared.get(\"wantodance_music_path\", None) is not None:\n            inputs_shared[\"music_feature\"] = self.get_music_base_feature(inputs_shared[\"wantodance_music_path\"]).to(dtype=pipe.torch_dtype, device=pipe.device)\n        return inputs_shared, inputs_posi, inputs_nega\n\n\nclass WanVideoUnit_WanToDance_RefImageEmbedder(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            input_params=(\"wantodance_reference_image\", \"num_frames\", \"height\", \"width\", \"tiled\", \"tile_size\", \"tile_stride\"),\n            output_params=(\"wantodance_refimage_feature\",),\n            onload_model_names=(\"image_encoder\", \"vae\")\n        )\n\n    def process(self, pipe: WanVideoPipeline, wantodance_reference_image, num_frames, height, width, tiled, tile_size, tile_stride):\n        if wantodance_reference_image is None:\n            return {}\n        pipe.load_models_to_device(self.onload_model_names)\n        if isinstance(wantodance_reference_image, list):\n            wantodance_reference_image = wantodance_reference_image[0]\n        image = 
pipe.preprocess_image(wantodance_reference_image.resize((width, height))).to(pipe.device) # B,C,H,W;B=1\n        refimage_feature = pipe.image_encoder.encode_image([image])\n        refimage_feature = refimage_feature.to(dtype=pipe.torch_dtype, device=pipe.device)\n        return {\"wantodance_refimage_feature\": refimage_feature}\n\n\nclass WanVideoUnit_WanToDance_ImageKeyframesEmbedder(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            input_params=(\"wantodance_keyframes\", \"wantodance_keyframes_mask\", \"num_frames\", \"height\", \"width\", \"tiled\", \"tile_size\", \"tile_stride\"),\n            output_params=(\"clip_feature\", \"y\"),\n            onload_model_names=(\"image_encoder\", \"vae\")\n        )\n\n    def process(self, pipe: WanVideoPipeline, wantodance_keyframes, wantodance_keyframes_mask, num_frames, height, width, tiled, tile_size, tile_stride):\n        if wantodance_keyframes is None:\n            return {}\n        wantodance_keyframes_mask = torch.tensor(wantodance_keyframes_mask)\n        pipe.load_models_to_device(self.onload_model_names)\n        images = []\n        for input_image in wantodance_keyframes:\n            input_image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device)\n            images.append(input_image)\n    \n        clip_context = pipe.image_encoder.encode_image(images[:1]) # 取第一帧作为clip输入\n        msk = torch.zeros(1, num_frames, height//8, width//8, device=pipe.device)\n        msk[:, wantodance_keyframes_mask==1, :, :] = torch.ones(1, height//8, width//8, device=pipe.device) # set keyframes mask to 1\n        \n        images = [image.transpose(0, 1) for image in images]  # 3, num_frames, h, w\n        images = torch.concat(images, dim=1) \n        vae_input = images\n\n        msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) # expand first frame mask, N to N + 3\n        msk = msk.view(1, msk.shape[1] // 4, 4, 
height//8, width//8)\n        msk = msk.transpose(1, 2)[0]\n        \n        y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]\n        y = y.to(dtype=pipe.torch_dtype, device=pipe.device)\n        y = torch.concat([msk, y])\n        y = y.unsqueeze(0)\n        clip_context = clip_context.to(dtype=pipe.torch_dtype, device=pipe.device)\n        y = y.to(dtype=pipe.torch_dtype, device=pipe.device)\n        return {\"clip_feature\": clip_context, \"y\": y}\n\n\nclass TeaCache:\n    def __init__(self, num_inference_steps, rel_l1_thresh, model_id):\n        self.num_inference_steps = num_inference_steps\n        self.step = 0\n        self.accumulated_rel_l1_distance = 0\n        self.previous_modulated_input = None\n        self.rel_l1_thresh = rel_l1_thresh\n        self.previous_residual = None\n        self.previous_hidden_states = None\n        \n        self.coefficients_dict = {\n            \"Wan2.1-T2V-1.3B\": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02],\n            \"Wan2.1-T2V-14B\": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01],\n            \"Wan2.1-I2V-14B-480P\": [2.57151496e+05, -3.54229917e+04,  1.40286849e+03, -1.35890334e+01, 1.32517977e-01],\n            \"Wan2.1-I2V-14B-720P\": [ 8.10705460e+03,  2.13393892e+03, -3.72934672e+02,  1.66203073e+01, -4.17769401e-02],\n        }\n        if model_id not in self.coefficients_dict:\n            supported_model_ids = \", \".join([i for i in self.coefficients_dict])\n            raise ValueError(f\"{model_id} is not a supported TeaCache model id. 
Please choose a valid model id in ({supported_model_ids}).\")\n        self.coefficients = self.coefficients_dict[model_id]\n\n    def check(self, dit: WanModel, x, t_mod):\n        modulated_inp = t_mod.clone()\n        if self.step == 0 or self.step == self.num_inference_steps - 1:\n            should_calc = True\n            self.accumulated_rel_l1_distance = 0\n        else:\n            coefficients = self.coefficients\n            rescale_func = np.poly1d(coefficients)\n            self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())\n            if self.accumulated_rel_l1_distance < self.rel_l1_thresh:\n                should_calc = False\n            else:\n                should_calc = True\n                self.accumulated_rel_l1_distance = 0\n        self.previous_modulated_input = modulated_inp\n        self.step += 1\n        if self.step == self.num_inference_steps:\n            self.step = 0\n        if should_calc:\n            self.previous_hidden_states = x.clone()\n        return not should_calc\n\n    def store(self, hidden_states):\n        self.previous_residual = hidden_states - self.previous_hidden_states\n        self.previous_hidden_states = None\n\n    def update(self, hidden_states):\n        hidden_states = hidden_states + self.previous_residual\n        return hidden_states\n\n\n\nclass TemporalTiler_BCTHW:\n    def __init__(self):\n        pass\n\n    def build_1d_mask(self, length, left_bound, right_bound, border_width):\n        x = torch.ones((length,))\n        if border_width == 0:\n            return x\n        \n        shift = 0.5\n        if not left_bound:\n            x[:border_width] = (torch.arange(border_width) + shift) / border_width\n        if not right_bound:\n            x[-border_width:] = torch.flip((torch.arange(border_width) + shift) / border_width, dims=(0,))\n        return x\n\n    def 
build_mask(self, data, is_bound, border_width):\n        _, _, T, _, _ = data.shape\n        t = self.build_1d_mask(T, is_bound[0], is_bound[1], border_width[0])\n        mask = repeat(t, \"T -> 1 1 T 1 1\")\n        return mask\n    \n    def run(self, model_fn, sliding_window_size, sliding_window_stride, computation_device, computation_dtype, model_kwargs, tensor_names, batch_size=None):\n        tensor_names = [tensor_name for tensor_name in tensor_names if model_kwargs.get(tensor_name) is not None]\n        tensor_dict = {tensor_name: model_kwargs[tensor_name] for tensor_name in tensor_names}\n        B, C, T, H, W = tensor_dict[tensor_names[0]].shape\n        if batch_size is not None:\n            B *= batch_size\n        data_device, data_dtype = tensor_dict[tensor_names[0]].device, tensor_dict[tensor_names[0]].dtype\n        value = torch.zeros((B, C, T, H, W), device=data_device, dtype=data_dtype)\n        weight = torch.zeros((1, 1, T, 1, 1), device=data_device, dtype=data_dtype)\n        for t in range(0, T, sliding_window_stride):\n            if t - sliding_window_stride >= 0 and t - sliding_window_stride + sliding_window_size >= T:\n                continue\n            t_ = min(t + sliding_window_size, T)\n            model_kwargs.update({\n                tensor_name: tensor_dict[tensor_name][:, :, t: t_:, :].to(device=computation_device, dtype=computation_dtype) \\\n                    for tensor_name in tensor_names\n            })\n            model_output = model_fn(**model_kwargs).to(device=data_device, dtype=data_dtype)\n            mask = self.build_mask(\n                model_output,\n                is_bound=(t == 0, t_ == T),\n                border_width=(sliding_window_size - sliding_window_stride,)\n            ).to(device=data_device, dtype=data_dtype)\n            value[:, :, t: t_, :, :] += model_output * mask\n            weight[:, :, t: t_, :, :] += mask\n        value /= weight\n        model_kwargs.update(tensor_dict)\n        
return value\n\n\ndef wantodance_get_single_freqs(freqs, frame_num, fps):\n    total_frame = int(30.0 / (fps + 1e-6) * frame_num + 0.5)\n    interval_frame = 30.0 / (fps + 1e-6)\n    freqs_0 = freqs[:total_frame]\n    freqs_new = torch.zeros((frame_num, freqs_0.shape[1]), device=freqs_0.device, dtype=freqs_0.dtype)\n    freqs_new[0] = freqs_0[0]\n    freqs_new[-1] = freqs_0[total_frame - 1]\n    for i in range(1, frame_num-1):\n        pos = i * interval_frame\n        low_idx = int(pos)\n        high_idx = min(low_idx + 1, total_frame - 1)\n        weight_high = pos - low_idx\n        weight_low = 1.0 - weight_high\n        freqs_new[i] = freqs_0[low_idx] * weight_low + freqs_0[high_idx] * weight_high\n    return freqs_new\n\n\ndef model_fn_wan_video(\n    dit: WanModel,\n    motion_controller: WanMotionControllerModel = None,\n    vace: VaceWanModel = None,\n    vap: MotWanModel = None,\n    animate_adapter: WanAnimateAdapter = None,\n    latents: torch.Tensor = None,\n    timestep: torch.Tensor = None,\n    context: torch.Tensor = None,\n    clip_feature: Optional[torch.Tensor] = None,\n    y: Optional[torch.Tensor] = None,\n    reference_latents = None,\n    vace_context = None,\n    vace_scale = 1.0,\n    audio_embeds: Optional[torch.Tensor] = None,\n    motion_latents: Optional[torch.Tensor] = None,\n    s2v_pose_latents: Optional[torch.Tensor] = None,\n    vap_hidden_state = None,\n    vap_clip_feature = None,\n    context_vap = None,\n    drop_motion_frames: bool = True,\n    tea_cache: TeaCache = None,\n    use_unified_sequence_parallel: bool = False,\n    motion_bucket_id: Optional[torch.Tensor] = None,\n    pose_latents=None,\n    face_pixel_values=None,\n    longcat_latents=None,\n    sliding_window_size: Optional[int] = None,\n    sliding_window_stride: Optional[int] = None,\n    cfg_merge: bool = False,\n    use_gradient_checkpointing: bool = False,\n    use_gradient_checkpointing_offload: bool = False,\n    control_camera_latents_input = None,\n    
fuse_vae_embedding_in_latents: bool = False,\n    wantodance_refimage_feature = None,\n    wantodance_fps: float = 30.0,\n    music_feature = None,\n    skip_9th_layer: bool = False,\n    **kwargs,\n):\n    if sliding_window_size is not None and sliding_window_stride is not None:\n        model_kwargs = dict(\n            dit=dit,\n            motion_controller=motion_controller,\n            vace=vace,\n            latents=latents,\n            timestep=timestep,\n            context=context,\n            clip_feature=clip_feature,\n            y=y,\n            reference_latents=reference_latents,\n            vace_context=vace_context,\n            vace_scale=vace_scale,\n            tea_cache=tea_cache,\n            use_unified_sequence_parallel=use_unified_sequence_parallel,\n            motion_bucket_id=motion_bucket_id,\n        )\n        return TemporalTiler_BCTHW().run(\n            model_fn_wan_video,\n            sliding_window_size, sliding_window_stride,\n            latents.device, latents.dtype,\n            model_kwargs=model_kwargs,\n            tensor_names=[\"latents\", \"y\"],\n            batch_size=2 if cfg_merge else 1\n        )\n    # LongCat-Video\n    if isinstance(dit, LongCatVideoTransformer3DModel):\n        return model_fn_longcat_video(\n            dit=dit,\n            latents=latents,\n            timestep=timestep,\n            context=context,\n            longcat_latents=longcat_latents,\n            use_gradient_checkpointing=use_gradient_checkpointing,\n            use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,\n        )\n        \n    # wan2.2 s2v\n    if audio_embeds is not None:\n        return model_fn_wans2v(\n            dit=dit,\n            latents=latents,\n            timestep=timestep,\n            context=context,\n            audio_embeds=audio_embeds,\n            motion_latents=motion_latents,\n            s2v_pose_latents=s2v_pose_latents,\n            
drop_motion_frames=drop_motion_frames,\n            use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,\n            use_gradient_checkpointing=use_gradient_checkpointing,\n            use_unified_sequence_parallel=use_unified_sequence_parallel,\n        )\n\n    if use_unified_sequence_parallel:\n        import torch.distributed as dist\n        from xfuser.core.distributed import (get_sequence_parallel_rank,\n                                            get_sequence_parallel_world_size,\n                                            get_sp_group)\n\n    # Timestep\n    if dit.seperated_timestep and fuse_vae_embedding_in_latents:\n        timestep = torch.concat([\n            torch.zeros((1, latents.shape[3] * latents.shape[4] // 4), dtype=latents.dtype, device=latents.device),\n            torch.ones((latents.shape[2] - 1, latents.shape[3] * latents.shape[4] // 4), dtype=latents.dtype, device=latents.device) * timestep\n        ]).flatten()\n        t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep).unsqueeze(0))\n        if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1:\n            t_chunks = torch.chunk(t, get_sequence_parallel_world_size(), dim=1)\n            t_chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, t_chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in t_chunks]\n            t = t_chunks[get_sequence_parallel_rank()]\n        t_mod = dit.time_projection(t).unflatten(2, (6, dit.dim))\n    else:\n        t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))\n        t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim))\n    \n    # Motion Controller\n    if motion_bucket_id is not None and motion_controller is not None:\n        t_mod = t_mod + motion_controller(motion_bucket_id).unflatten(1, (6, dit.dim))\n    context = dit.text_embedding(context)\n\n    x = latents\n    # Merged cfg\n    if x.shape[0] != context.shape[0]:\n        x = 
torch.concat([x] * context.shape[0], dim=0)\n    if timestep.shape[0] != context.shape[0]:\n        timestep = torch.concat([timestep] * context.shape[0], dim=0)\n\n    # Image Embedding\n    if y is not None and dit.require_vae_embedding:\n        x = torch.cat([x, y], dim=1)\n    if clip_feature is not None and dit.require_clip_embedding:\n        clip_embdding = dit.img_emb(clip_feature)\n        context = torch.cat([clip_embdding, context], dim=1)\n        \n    # Camera control\n    if hasattr(dit, \"wantodance_enable_global\") and dit.wantodance_enable_global and int(wantodance_fps + 0.5) != 30:\n        x = dit.patchify(x, control_camera_latents_input, enable_wantodance_global=True)\n    else:\n        x = dit.patchify(x, control_camera_latents_input)\n    \n    # Animate\n    if pose_latents is not None and face_pixel_values is not None:\n        x, motion_vec = animate_adapter.after_patch_embedding(x, pose_latents, face_pixel_values)\n    \n    # Patchify\n    f, h, w = x.shape[2:]\n    x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous()\n    \n    # Reference image\n    if reference_latents is not None:\n        if len(reference_latents.shape) == 5:\n            reference_latents = reference_latents[:, :, 0]\n        reference_latents = dit.ref_conv(reference_latents).flatten(2).transpose(1, 2)\n        x = torch.concat([reference_latents, x], dim=1)\n        f += 1\n    \n    freqs = torch.cat([\n        dit.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),\n        dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),\n        dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)\n    ], dim=-1).reshape(f * h * w, 1, -1).to(x.device)\n\n    # VAP \n    if vap is not None:\n        # hidden state\n        x_vap = vap_hidden_state\n        x_vap = vap.patchify(x_vap)\n        x_vap = rearrange(x_vap, 'b c f h w -> b (f h w) c').contiguous()\n        # Timestep\n        clean_timestep = torch.ones(timestep.shape, 
device=timestep.device).to(timestep.dtype)\n        t = vap.time_embedding(sinusoidal_embedding_1d(vap.freq_dim, clean_timestep))\n        t_mod_vap = vap.time_projection(t).unflatten(1, (6, vap.dim))\n\n        # rope\n        freqs_vap = vap.compute_freqs_mot(f,h,w).to(x.device)\n\n        # context\n        vap_clip_embedding = vap.img_emb(vap_clip_feature)\n        context_vap = vap.text_embedding(context_vap)\n        context_vap = torch.cat([vap_clip_embedding, context_vap], dim=1)\n    \n    # TeaCache\n    if tea_cache is not None:\n        tea_cache_update = tea_cache.check(dit, x, t_mod)\n    else:\n        tea_cache_update = False\n        \n    if vace_context is not None:\n        vace_hints = vace(\n            x, vace_context, context, t_mod, freqs,\n            use_gradient_checkpointing=use_gradient_checkpointing,\n            use_gradient_checkpointing_offload=use_gradient_checkpointing_offload\n        )\n\n    # WanToDance\n    if hasattr(dit, \"wantodance_enable_global\") and dit.wantodance_enable_global:\n        if wantodance_refimage_feature is not None:\n            refimage_feature_embedding = dit.img_emb_refimage(wantodance_refimage_feature)\n            context = torch.cat([refimage_feature_embedding, context], dim=1)\n        if (dit.wantodance_enable_dynamicfps or dit.wantodance_enable_unimodel) and int(wantodance_fps + 0.5) != 30: \n            freqs_0 = wantodance_get_single_freqs(dit.freqs[0], f, wantodance_fps)\n            freqs = torch.cat([\n                freqs_0.view(f, 1, 1, -1).expand(f, h, w, -1),\n                dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),\n                dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)\n            ], dim=-1).reshape(f * h * w, 1, -1).to(x.device)\n        if dit.wantodance_enable_global or dit.wantodance_enable_dynamicfps or dit.wantodance_enable_unimodel:\n            if use_unified_sequence_parallel:\n                length = int(float(music_feature.shape[0]) / 
get_sequence_parallel_world_size()) * get_sequence_parallel_world_size()\n                music_feature = music_feature[:length]\n                music_feature = torch.chunk(music_feature, get_sequence_parallel_world_size(), dim=0)[get_sequence_parallel_rank()]\n            if not dit.training:\n                dit.music_encoder.to(x.device, dtype=x.dtype) # only evaluation\n            music_feature = music_feature.to(x.device, dtype=x.dtype)\n            music_feature = dit.music_projection(music_feature)\n            music_feature = dit.music_encoder(music_feature)\n            if music_feature.dim() == 2:\n                music_feature = music_feature.unsqueeze(0)\n            if use_unified_sequence_parallel:\n                if dist.is_initialized() and dist.get_world_size() > 1:\n                    music_feature = get_sp_group().all_gather(music_feature, dim=1)\n            music_feature = music_feature.unsqueeze(1) # [1, 1, 149, 4800]\n            N = 149 \n            M = 4800 \n            music_feature = torch.nn.functional.interpolate(music_feature, size=(N, M), mode='bilinear') \n            music_feature = music_feature.squeeze(1) # shape: [1, 149, 4800]\n        if music_feature is not None:\n            if music_feature.dim() == 2:\n                music_feature = music_feature.unsqueeze(0)\n            music_feature = music_feature.to(x.device, dtype=x.dtype)\n            interp_mode = 'bilinear'\n            if interp_mode == 'bilinear':\n                frame_num = latents.shape[2] if len(latents.shape) == 5 else latents.shape[1] # 21\n                context_shape_end = context.shape[2] ## 14B 5120\n                music_feature = music_feature.unsqueeze(1) # shape: [1, 1, 149, 4800]\n                if use_unified_sequence_parallel:\n                    N = int(float(frame_num * 8) / get_sequence_parallel_world_size()) * get_sequence_parallel_world_size()\n                else:\n                    N = frame_num * 8\n                
music_feature = torch.nn.functional.interpolate(music_feature, size=(N, context_shape_end), mode='bilinear') \n                music_feature = music_feature.squeeze(1) # shape: [1, N, context_shape_end]\n                if use_unified_sequence_parallel:\n                    dit.merged_audio_emb = torch.chunk(music_feature, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]\n                else:\n                    dit.merged_audio_emb = music_feature\n            else: \n                dit.merged_audio_emb = music_feature\n\n    # blocks\n    if use_unified_sequence_parallel:\n        if dist.is_initialized() and dist.get_world_size() > 1:\n            chunks = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)\n            pad_shape = chunks[0].shape[1] - chunks[-1].shape[1]\n            chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in chunks]\n            x = chunks[get_sequence_parallel_rank()]\n    if tea_cache_update:\n        x = tea_cache.update(x)\n    else:\n        def create_custom_forward_vap(block, vap):\n            def custom_forward(*inputs):\n                return vap(block, *inputs)\n            return custom_forward\n        \n        # Block\n        for block_id, block in enumerate(dit.blocks):\n            if skip_9th_layer:\n                # This is only used in WanToDance\n                if block_id == 9:\n                    continue\n            if vap is not None and block_id in vap.mot_layers_mapping:\n                if use_gradient_checkpointing_offload:\n                    with torch.autograd.graph.save_on_cpu():\n                        x, x_vap = torch.utils.checkpoint.checkpoint(\n                            create_custom_forward_vap(block, vap),\n                            x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id,\n                            use_reentrant=False\n                        )\n           
     elif use_gradient_checkpointing:\n                    x, x_vap = torch.utils.checkpoint.checkpoint(\n                        create_custom_forward_vap(block, vap),\n                        x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id,\n                        use_reentrant=False\n                    )\n                else:\n                    x, x_vap = vap(block, x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id)\n            else:\n                x = gradient_checkpoint_forward(\n                    block,\n                    use_gradient_checkpointing,\n                    use_gradient_checkpointing_offload,\n                    x, context, t_mod, freqs\n                )\n              \n            \n            # VACE\n            if vace_context is not None and block_id in vace.vace_layers_mapping:\n                current_vace_hint = vace_hints[vace.vace_layers_mapping[block_id]]\n                if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1:\n                    current_vace_hint = torch.chunk(current_vace_hint, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]\n                    current_vace_hint = torch.nn.functional.pad(current_vace_hint, (0, 0, 0, chunks[0].shape[1] - current_vace_hint.shape[1]), value=0)\n                x = x + current_vace_hint * vace_scale\n            \n            # Animate\n            if pose_latents is not None and face_pixel_values is not None:\n                x = animate_adapter.after_transformer_block(block_id, x, motion_vec)\n            \n            # WanToDance\n            if hasattr(dit, \"wantodance_enable_music_inject\") and dit.wantodance_enable_music_inject:\n                x = dit.wantodance_after_transformer_block(block_id, x)\n        if tea_cache is not None:\n            tea_cache.store(x)\n            \n    if hasattr(dit, \"wantodance_enable_unimodel\") and 
dit.wantodance_enable_unimodel and int(wantodance_fps + 0.5) != 30:\n        x = dit.head_global(x, t)\n    else:\n        x = dit.head(x, t)\n    \n    if use_unified_sequence_parallel:\n        if dist.is_initialized() and dist.get_world_size() > 1:\n            x = get_sp_group().all_gather(x, dim=1)\n            x = x[:, :-pad_shape] if pad_shape > 0 else x\n    # Remove reference latents\n    if reference_latents is not None:\n        x = x[:, reference_latents.shape[1]:]\n        f -= 1\n    x = dit.unpatchify(x, (f, h, w))\n    return x\n\n\ndef model_fn_longcat_video(\n    dit: LongCatVideoTransformer3DModel,\n    latents: torch.Tensor = None,\n    timestep: torch.Tensor = None,\n    context: torch.Tensor = None,\n    longcat_latents: torch.Tensor = None,\n    use_gradient_checkpointing=False,\n    use_gradient_checkpointing_offload=False,\n):\n    if longcat_latents is not None:\n        latents[:, :, :longcat_latents.shape[2]] = longcat_latents\n        num_cond_latents = longcat_latents.shape[2]\n    else:\n        num_cond_latents = 0\n    context = context.unsqueeze(0)\n    encoder_attention_mask = torch.any(context != 0, dim=-1)[:, 0].to(torch.int64)\n    output = dit(\n        latents,\n        timestep,\n        context,\n        encoder_attention_mask,\n        num_cond_latents=num_cond_latents,\n        use_gradient_checkpointing=use_gradient_checkpointing,\n        use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,\n    )\n    output = -output\n    output = output.to(latents.dtype)\n    return output\n\n\ndef model_fn_wans2v(\n    dit,\n    latents,\n    timestep,\n    context,\n    audio_embeds,\n    motion_latents,\n    s2v_pose_latents,\n    drop_motion_frames=True,\n    use_gradient_checkpointing_offload=False,\n    use_gradient_checkpointing=False,\n    use_unified_sequence_parallel=False,\n):\n    if use_unified_sequence_parallel:\n        import torch.distributed as dist\n        from xfuser.core.distributed import 
(get_sequence_parallel_rank,\n                                            get_sequence_parallel_world_size,\n                                            get_sp_group)\n    origin_ref_latents = latents[:, :, 0:1]\n    x = latents[:, :, 1:]\n\n    # context embedding\n    context = dit.text_embedding(context)\n\n    # audio encode\n    audio_emb_global, merged_audio_emb = dit.cal_audio_emb(audio_embeds)\n\n    # x and s2v_pose_latents\n    s2v_pose_latents = torch.zeros_like(x) if s2v_pose_latents is None else s2v_pose_latents\n    x, (f, h, w) = dit.patchify(dit.patch_embedding(x) + dit.cond_encoder(s2v_pose_latents))\n    seq_len_x = seq_len_x_global = x.shape[1] # global used for unified sequence parallel\n\n    # reference image\n    ref_latents, (rf, rh, rw) = dit.patchify(dit.patch_embedding(origin_ref_latents))\n    grid_sizes = dit.get_grid_sizes((f, h, w), (rf, rh, rw))\n    x = torch.cat([x, ref_latents], dim=1)\n    # mask\n    mask = torch.cat([torch.zeros([1, seq_len_x]), torch.ones([1, ref_latents.shape[1]])], dim=1).to(torch.long).to(x.device)\n    # freqs\n    pre_compute_freqs = rope_precompute(x.detach().view(1, x.size(1), dit.num_heads, dit.dim // dit.num_heads), grid_sizes, dit.freqs, start=None)\n    # motion\n    x, pre_compute_freqs, mask = dit.inject_motion(x, pre_compute_freqs, mask, motion_latents, drop_motion_frames=drop_motion_frames, add_last_motion=2)\n\n    x = x + dit.trainable_cond_mask(mask).to(x.dtype)\n\n    # tmod\n    timestep = torch.cat([timestep, torch.zeros([1], dtype=timestep.dtype, device=timestep.device)])\n    t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))\n    t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim)).unsqueeze(2).transpose(0, 2)\n\n    if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1:\n        world_size, sp_rank = get_sequence_parallel_world_size(), get_sequence_parallel_rank()\n        assert x.shape[1] % world_size == 0, f\"the dimension 
after chunk must be divisible by world size, but got {x.shape[1]} and {get_sequence_parallel_world_size()}\"\n        x = torch.chunk(x, world_size, dim=1)[sp_rank]\n        seg_idxs = [0] + list(torch.cumsum(torch.tensor([x.shape[1]] * world_size), dim=0).cpu().numpy())\n        seq_len_x_list = [min(max(0, seq_len_x - seg_idxs[i]), x.shape[1]) for i in range(len(seg_idxs)-1)]\n        seq_len_x = seq_len_x_list[sp_rank]\n\n    def create_custom_forward(module):\n        def custom_forward(*inputs):\n            return module(*inputs)\n        return custom_forward\n\n    for block_id, block in enumerate(dit.blocks):\n        x = gradient_checkpoint_forward(\n                block,\n                use_gradient_checkpointing,\n                use_gradient_checkpointing_offload,\n                x, context, t_mod, seq_len_x, pre_compute_freqs[0]\n            )\n        x = gradient_checkpoint_forward(\n            lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x),\n            use_gradient_checkpointing,\n            use_gradient_checkpointing_offload,\n            x\n        )\n\n    if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1:\n        x = get_sp_group().all_gather(x, dim=1)\n\n    x = x[:, :seq_len_x_global]\n    x = dit.head(x, t[:-1])\n    x = dit.unpatchify(x, (f, h, w))\n    # make compatible with wan video\n    x = torch.cat([origin_ref_latents, x], dim=2)\n    return x\n"
  },
  {
    "path": "diffsynth/pipelines/z_image.py",
    "content": "import torch, math, warnings\nfrom PIL import Image\nfrom typing import Union\nfrom tqdm import tqdm\nfrom einops import rearrange\nimport numpy as np\nfrom typing import Union, List, Optional, Tuple, Iterable, Dict\n\nfrom ..core.device.npu_compatible_device import get_device_type, IS_NPU_AVAILABLE\nfrom ..diffusion import FlowMatchScheduler\nfrom ..core import ModelConfig, gradient_checkpoint_forward\nfrom ..core.data.operators import ImageCropAndResize\nfrom ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput\nfrom ..utils.lora import merge_lora\n\nfrom transformers import AutoTokenizer\nfrom ..models.z_image_text_encoder import ZImageTextEncoder\nfrom ..models.z_image_dit import ZImageDiT\nfrom ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder\nfrom ..models.siglip2_image_encoder import Siglip2ImageEncoder428M\nfrom ..models.z_image_controlnet import ZImageControlNet\nfrom ..models.siglip2_image_encoder import Siglip2ImageEncoder\nfrom ..models.dinov3_image_encoder import DINOv3ImageEncoder\nfrom ..models.z_image_image2lora import ZImageImage2LoRAModel\n\n\nclass ZImagePipeline(BasePipeline):\n\n    def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):\n        super().__init__(\n            device=device, torch_dtype=torch_dtype,\n            height_division_factor=16, width_division_factor=16,\n        )\n        self.scheduler = FlowMatchScheduler(\"Z-Image\")\n        self.text_encoder: ZImageTextEncoder = None\n        self.dit: ZImageDiT = None\n        self.vae_encoder: FluxVAEEncoder = None\n        self.vae_decoder: FluxVAEDecoder = None\n        self.image_encoder: Siglip2ImageEncoder428M = None\n        self.controlnet: ZImageControlNet = None\n        self.siglip2_image_encoder: Siglip2ImageEncoder = None\n        self.dinov3_image_encoder: DINOv3ImageEncoder = None\n        self.image2lora_style: ZImageImage2LoRAModel = None\n        self.tokenizer: AutoTokenizer = None\n        
self.in_iteration_models = (\"dit\", \"controlnet\")\n        self.units = [\n            ZImageUnit_ShapeChecker(),\n            ZImageUnit_PromptEmbedder(),\n            ZImageUnit_NoiseInitializer(),\n            ZImageUnit_InputImageEmbedder(),\n            ZImageUnit_EditImageAutoResize(),\n            ZImageUnit_EditImageEmbedderVAE(),\n            ZImageUnit_EditImageEmbedderSiglip(),\n            ZImageUnit_PAIControlNet(),\n        ]\n        self.model_fn = model_fn_z_image\n    \n    \n    @staticmethod\n    def from_pretrained(\n        torch_dtype: torch.dtype = torch.bfloat16,\n        device: Union[str, torch.device] = get_device_type(),\n        model_configs: list[ModelConfig] = [],\n        tokenizer_config: ModelConfig = ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"tokenizer/\"),\n        vram_limit: float = None,\n        enable_npu_patch: bool = True,\n    ):\n        # Initialize pipeline\n        pipe = ZImagePipeline(device=device, torch_dtype=torch_dtype)\n        model_pool = pipe.download_and_load_models(model_configs, vram_limit)\n        \n        # Fetch models\n        pipe.text_encoder = model_pool.fetch_model(\"z_image_text_encoder\")\n        pipe.dit = model_pool.fetch_model(\"z_image_dit\")\n        pipe.vae_encoder = model_pool.fetch_model(\"flux_vae_encoder\")\n        pipe.vae_decoder = model_pool.fetch_model(\"flux_vae_decoder\")\n        pipe.image_encoder = model_pool.fetch_model(\"siglip_vision_model_428m\")\n        pipe.controlnet = model_pool.fetch_model(\"z_image_controlnet\")\n        pipe.siglip2_image_encoder = model_pool.fetch_model(\"siglip2_image_encoder\")\n        pipe.dinov3_image_encoder = model_pool.fetch_model(\"dinov3_image_encoder\")\n        pipe.image2lora_style = model_pool.fetch_model(\"z_image_image2lora_style\")\n        if tokenizer_config is not None:\n            tokenizer_config.download_if_necessary()\n            pipe.tokenizer = 
AutoTokenizer.from_pretrained(tokenizer_config.path)\n        \n        # VRAM Management\n        pipe.vram_management_enabled = pipe.check_vram_management_state()\n        # NPU patch\n        apply_npu_patch(enable_npu_patch)\n        return pipe\n    \n    \n    @torch.no_grad()\n    def __call__(\n        self,\n        # Prompt\n        prompt: str,\n        negative_prompt: str = \"\",\n        cfg_scale: float = 1.0,\n        # Image\n        input_image: Image.Image = None,\n        denoising_strength: float = 1.0,\n        # Edit\n        edit_image: Image.Image = None,\n        edit_image_auto_resize: bool = True,\n        # Shape\n        height: int = 1024,\n        width: int = 1024,\n        # Randomness\n        seed: int = None,\n        rand_device: str = \"cpu\",\n        # Steps\n        num_inference_steps: int = 8,\n        sigma_shift: float = None,\n        # ControlNet\n        controlnet_inputs: List[ControlNetInput] = None,\n        # Image to LoRA\n        image2lora_images: List[Image.Image] = None,\n        positive_only_lora: Dict[str, torch.Tensor] = None,\n        # Progress bar\n        progress_bar_cmd = tqdm,\n    ):\n        # Scheduler\n        self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)\n        \n        # Parameters\n        inputs_posi = {\n            \"prompt\": prompt,\n        }\n        inputs_nega = {\n            \"negative_prompt\": negative_prompt,\n        }\n        inputs_shared = {\n            \"cfg_scale\": cfg_scale,\n            \"input_image\": input_image, \"denoising_strength\": denoising_strength,\n            \"height\": height, \"width\": width,\n            \"seed\": seed, \"rand_device\": rand_device,\n            \"num_inference_steps\": num_inference_steps,\n            \"edit_image\": edit_image, \"edit_image_auto_resize\": edit_image_auto_resize,\n            \"controlnet_inputs\": controlnet_inputs,\n            
\"image2lora_images\": image2lora_images, \"positive_only_lora\": positive_only_lora,\n        }\n        for unit in self.units:\n            inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)\n\n        # Denoise\n        self.load_models_to_device(self.in_iteration_models)\n        models = {name: getattr(self, name) for name in self.in_iteration_models}\n        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):\n            timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)\n            noise_pred = self.cfg_guided_model_fn(\n                self.model_fn, cfg_scale,\n                inputs_shared, inputs_posi, inputs_nega,\n                **models, timestep=timestep, progress_id=progress_id\n            )\n            inputs_shared[\"latents\"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)\n        \n        # Decode\n        self.load_models_to_device(['vae_decoder'])\n        image = self.vae_decoder(inputs_shared[\"latents\"])\n        image = self.vae_output_to_image(image)\n        self.load_models_to_device([])\n\n        return image\n\n\nclass ZImageUnit_ShapeChecker(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            input_params=(\"height\", \"width\"),\n            output_params=(\"height\", \"width\"),\n        )\n\n    def process(self, pipe: ZImagePipeline, height, width):\n        height, width = pipe.check_resize_height_width(height, width)\n        return {\"height\": height, \"width\": width}\n\n\nclass ZImageUnit_PromptEmbedder(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            seperate_cfg=True,\n            input_params=(\"edit_image\",),\n            input_params_posi={\"prompt\": \"prompt\"},\n            input_params_nega={\"prompt\": \"negative_prompt\"},\n            output_params=(\"prompt_embeds\",),\n       
     onload_model_names=(\"text_encoder\",)\n        )\n\n    def encode_prompt(\n        self,\n        pipe,\n        prompt: Union[str, List[str]],\n        device: Optional[torch.device] = None,\n        max_sequence_length: int = 512,\n    ) -> List[torch.FloatTensor]:\n        if isinstance(prompt, str):\n            prompt = [prompt]\n\n        for i, prompt_item in enumerate(prompt):\n            messages = [\n                {\"role\": \"user\", \"content\": prompt_item},\n            ]\n            prompt_item = pipe.tokenizer.apply_chat_template(\n                messages,\n                tokenize=False,\n                add_generation_prompt=True,\n                enable_thinking=True,\n            )\n            prompt[i] = prompt_item\n\n        text_inputs = pipe.tokenizer(\n            prompt,\n            padding=\"max_length\",\n            max_length=max_sequence_length,\n            truncation=True,\n            return_tensors=\"pt\",\n        )\n\n        text_input_ids = text_inputs.input_ids.to(device)\n        prompt_masks = text_inputs.attention_mask.to(device).bool()\n\n        prompt_embeds = pipe.text_encoder(\n            input_ids=text_input_ids,\n            attention_mask=prompt_masks,\n            output_hidden_states=True,\n        ).hidden_states[-2]\n\n        embeddings_list = []\n\n        for i in range(len(prompt_embeds)):\n            embeddings_list.append(prompt_embeds[i][prompt_masks[i]])\n\n        return embeddings_list\n    \n    def encode_prompt_omni(\n        self,\n        pipe,\n        prompt: Union[str, List[str]],\n        edit_image=None,\n        device: Optional[torch.device] = None,\n        max_sequence_length: int = 512,\n    ) -> List[torch.FloatTensor]:\n        if isinstance(prompt, str):\n            prompt = [prompt]\n\n        if edit_image is None:\n            num_condition_images = 0\n        elif isinstance(edit_image, list):\n            num_condition_images = len(edit_image)\n        else:\n  
          num_condition_images = 1\n\n        for i, prompt_item in enumerate(prompt):\n            if num_condition_images == 0:\n                prompt[i] = [\"<|im_start|>user\\n\" + prompt_item + \"<|im_end|>\\n<|im_start|>assistant\\n\"]\n            elif num_condition_images > 0:\n                prompt_list = [\"<|im_start|>user\\n<|vision_start|>\"]\n                prompt_list += [\"<|vision_end|><|vision_start|>\"] * (num_condition_images - 1)\n                prompt_list += [\"<|vision_end|>\" + prompt_item + \"<|im_end|>\\n<|im_start|>assistant\\n<|vision_start|>\"]\n                prompt_list += [\"<|vision_end|><|im_end|>\"]\n                prompt[i] = prompt_list\n\n        flattened_prompt = []\n        prompt_list_lengths = []\n\n        for i in range(len(prompt)):\n            prompt_list_lengths.append(len(prompt[i]))\n            flattened_prompt.extend(prompt[i])\n\n        text_inputs = pipe.tokenizer(\n            flattened_prompt,\n            padding=\"max_length\",\n            max_length=max_sequence_length,\n            truncation=True,\n            return_tensors=\"pt\",\n        )\n\n        text_input_ids = text_inputs.input_ids.to(device)\n        prompt_masks = text_inputs.attention_mask.to(device).bool()\n\n        prompt_embeds = pipe.text_encoder(\n            input_ids=text_input_ids,\n            attention_mask=prompt_masks,\n            output_hidden_states=True,\n        ).hidden_states[-2]\n\n        embeddings_list = []\n        start_idx = 0\n        for i in range(len(prompt_list_lengths)):\n            batch_embeddings = []\n            end_idx = start_idx + prompt_list_lengths[i]\n            for j in range(start_idx, end_idx):\n                batch_embeddings.append(prompt_embeds[j][prompt_masks[j]])\n            embeddings_list.append(batch_embeddings)\n            start_idx = end_idx\n\n        return embeddings_list\n\n    def process(self, pipe: ZImagePipeline, prompt, edit_image):\n        
pipe.load_models_to_device(self.onload_model_names)\n        if hasattr(pipe, \"dit\") and pipe.dit is not None and pipe.dit.siglip_embedder is not None:\n            # Z-Image-Turbo and Z-Image-Omni-Base use different prompt encoding methods.\n            # We determine which encoding method to use based on the model architecture.\n            # If you are using two-stage split training,\n            # please use `--offload_models` instead of skipping the DiT model loading.\n            prompt_embeds = self.encode_prompt_omni(pipe, prompt, edit_image, pipe.device)\n        else:\n            prompt_embeds = self.encode_prompt(pipe, prompt, pipe.device)\n        return {\"prompt_embeds\": prompt_embeds}\n\n\nclass ZImageUnit_NoiseInitializer(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            input_params=(\"height\", \"width\", \"seed\", \"rand_device\"),\n            output_params=(\"noise\",),\n        )\n\n    def process(self, pipe: ZImagePipeline, height, width, seed, rand_device):\n        noise = pipe.generate_noise((1, 16, height//8, width//8), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)\n        return {\"noise\": noise}\n\n\nclass ZImageUnit_InputImageEmbedder(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            input_params=(\"input_image\", \"noise\"),\n            output_params=(\"latents\", \"input_latents\"),\n            onload_model_names=(\"vae_encoder\",)\n        )\n\n    def process(self, pipe: ZImagePipeline, input_image, noise):\n        if input_image is None:\n            return {\"latents\": noise, \"input_latents\": None}\n        pipe.load_models_to_device(['vae'])\n        image = pipe.preprocess_image(input_image)\n        input_latents = pipe.vae_encoder(image)\n        if pipe.scheduler.training:\n            return {\"latents\": noise, \"input_latents\": input_latents}\n        else:\n            latents = pipe.scheduler.add_noise(input_latents, 
noise, timestep=pipe.scheduler.timesteps[0])\n            return {\"latents\": latents, \"input_latents\": input_latents}\n\n\nclass ZImageUnit_EditImageAutoResize(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            input_params=(\"edit_image\", \"edit_image_auto_resize\"),\n            output_params=(\"edit_image\",),\n        )\n\n    def process(self, pipe: ZImagePipeline, edit_image, edit_image_auto_resize):\n        if edit_image is None:\n            return {}\n        if edit_image_auto_resize is None or not edit_image_auto_resize:\n            return {}\n        operator = ImageCropAndResize(max_pixels=1024*1024, height_division_factor=16, width_division_factor=16)\n        if not isinstance(edit_image, list):\n            edit_image = [edit_image]\n        edit_image = [operator(i) for i in edit_image]\n        return {\"edit_image\": edit_image}\n\n\nclass ZImageUnit_EditImageEmbedderSiglip(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            input_params=(\"edit_image\",),\n            output_params=(\"image_embeds\",),\n            onload_model_names=(\"image_encoder\",)\n        )\n\n    def process(self, pipe: ZImagePipeline, edit_image):\n        if edit_image is None:\n            return {}\n        pipe.load_models_to_device(self.onload_model_names)\n        if not isinstance(edit_image, list):\n            edit_image = [edit_image]\n        image_emb = []\n        for image_ in edit_image:\n            image_emb.append(pipe.image_encoder(image_, device=pipe.device))\n        return {\"image_embeds\": image_emb}\n\n\nclass ZImageUnit_EditImageEmbedderVAE(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            input_params=(\"edit_image\",),\n            output_params=(\"image_latents\",),\n            onload_model_names=(\"vae_encoder\",)\n        )\n\n    def process(self, pipe: ZImagePipeline, edit_image):\n        if edit_image is None:\n            return {}\n        
pipe.load_models_to_device(self.onload_model_names)\n        if not isinstance(edit_image, list):\n            edit_image = [edit_image]\n        image_latents = []\n        for image_ in edit_image:\n            image_ = pipe.preprocess_image(image_)\n            image_latents.append(pipe.vae_encoder(image_))\n        return {\"image_latents\": image_latents}\n\n\nclass ZImageUnit_PAIControlNet(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            input_params=(\"controlnet_inputs\", \"height\", \"width\"),\n            output_params=(\"control_context\", \"control_scale\"),\n            onload_model_names=(\"vae_encoder\",)\n        )\n\n    def process(self, pipe: ZImagePipeline, controlnet_inputs: List[ControlNetInput], height, width):\n        if controlnet_inputs is None:\n            return {}\n        if len(controlnet_inputs) != 1:\n            print(\"Z-Image ControlNet doesn't support multi-ControlNet. Only one image will be used.\")\n        controlnet_input = controlnet_inputs[0]\n        pipe.load_models_to_device(self.onload_model_names)\n\n        control_image = controlnet_input.image\n        if control_image is not None:\n            control_image = pipe.preprocess_image(control_image)\n            control_latents = pipe.vae_encoder(control_image)\n        else:\n            control_latents = torch.ones((1, 16, height // 8, width // 8), dtype=pipe.torch_dtype, device=pipe.device) * -1\n        \n        inpaint_mask = controlnet_input.inpaint_mask\n        if inpaint_mask is not None:\n            inpaint_mask = pipe.preprocess_image(inpaint_mask, min_value=0, max_value=1)\n            inpaint_image = controlnet_input.inpaint_image\n            inpaint_image = pipe.preprocess_image(inpaint_image)\n            inpaint_image = inpaint_image * (inpaint_mask < 0.5)\n            inpaint_mask = torch.nn.functional.interpolate(1 - inpaint_mask, (height // 8, width // 8), mode='nearest')[:, :1]\n        else:\n            
inpaint_mask = torch.zeros((1, 1, height // 8, width // 8), dtype=pipe.torch_dtype, device=pipe.device)\n            inpaint_image = torch.zeros((1, 3, height, width), dtype=pipe.torch_dtype, device=pipe.device)\n        inpaint_latent = pipe.vae_encoder(inpaint_image)\n\n        control_context = torch.concat([control_latents, inpaint_mask, inpaint_latent], dim=1)\n        control_context = rearrange(control_context, \"B C H W -> B C 1 H W\")\n        return {\"control_context\": control_context, \"control_scale\": controlnet_input.scale}\n\n\ndef model_fn_z_image(\n    dit: ZImageDiT,\n    controlnet: ZImageControlNet = None,\n    latents=None,\n    timestep=None,\n    prompt_embeds=None,\n    image_embeds=None,\n    image_latents=None,\n    use_gradient_checkpointing=False,\n    use_gradient_checkpointing_offload=False,\n    **kwargs,\n):\n    # Due to the complex and verbose codebase of Z-Image,\n    # we are temporarily using this inelegant structure.\n    # We will refactor this part in the future (if time permits).\n    if dit.siglip_embedder is None:\n        return model_fn_z_image_turbo(\n            dit,\n            controlnet=controlnet,\n            latents=latents,\n            timestep=timestep,\n            prompt_embeds=prompt_embeds,\n            image_embeds=image_embeds,\n            image_latents=image_latents,\n            use_gradient_checkpointing=use_gradient_checkpointing,\n            use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,\n            **kwargs,\n        )\n    latents = [rearrange(latents, \"B C H W -> C B H W\")]\n    if dit.siglip_embedder is not None:\n        if image_latents is not None:\n            image_latents = [rearrange(image_latent, \"B C H W -> C B H W\") for image_latent in image_latents]\n            latents = [image_latents + latents]\n            image_noise_mask = [[0] * len(image_latents) + [1]]\n        else:\n            latents = [latents]\n            image_noise_mask = [[1]]\n     
   image_embeds = [image_embeds]\n    else:\n        image_noise_mask = None\n    timestep = (1000 - timestep) / 1000\n    model_output = dit(\n        latents,\n        timestep,\n        prompt_embeds,\n        siglip_feats=image_embeds,\n        image_noise_mask=image_noise_mask,\n        use_gradient_checkpointing=use_gradient_checkpointing,\n        use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,\n    )[0]\n    model_output = -model_output\n    model_output = rearrange(model_output, \"C B H W -> B C H W\")\n    return model_output\n\n\nclass ZImageUnit_Image2LoRAEncode(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            input_params=(\"image2lora_images\",),\n            output_params=(\"image2lora_x\",),\n            onload_model_names=(\"siglip2_image_encoder\", \"dinov3_image_encoder\",),\n        )\n        from ..core.data.operators import ImageCropAndResize\n        self.processor_highres = ImageCropAndResize(height=1024, width=1024)\n    \n    def encode_images_using_siglip2(self, pipe: ZImagePipeline, images: list[Image.Image]):\n        pipe.load_models_to_device([\"siglip2_image_encoder\"])\n        embs = []\n        for image in images:\n            image = self.processor_highres(image)\n            embs.append(pipe.siglip2_image_encoder(image).to(pipe.torch_dtype))\n        embs = torch.stack(embs)\n        return embs\n    \n    def encode_images_using_dinov3(self, pipe: ZImagePipeline, images: list[Image.Image]):\n        pipe.load_models_to_device([\"dinov3_image_encoder\"])\n        embs = []\n        for image in images:\n            image = self.processor_highres(image)\n            embs.append(pipe.dinov3_image_encoder(image).to(pipe.torch_dtype))\n        embs = torch.stack(embs)\n        return embs\n\n    def encode_images(self, pipe: ZImagePipeline, images: list[Image.Image]):\n        if images is None:\n            return {}\n        if not isinstance(images, list):\n            
images = [images]\n        embs_siglip2 = self.encode_images_using_siglip2(pipe, images)\n        embs_dinov3 = self.encode_images_using_dinov3(pipe, images)\n        x = torch.concat([embs_siglip2, embs_dinov3], dim=-1)\n        return x\n\n    def process(self, pipe: ZImagePipeline, image2lora_images):\n        if image2lora_images is None:\n            return {}\n        x = self.encode_images(pipe, image2lora_images)\n        return {\"image2lora_x\": x}\n\n\nclass ZImageUnit_Image2LoRADecode(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            input_params=(\"image2lora_x\",),\n            output_params=(\"lora\",),\n            onload_model_names=(\"image2lora_style\",),\n        )\n    \n    def process(self, pipe: ZImagePipeline, image2lora_x):\n        if image2lora_x is None:\n            return {}\n        loras = []\n        if pipe.image2lora_style is not None:\n            pipe.load_models_to_device([\"image2lora_style\"])\n            for x in image2lora_x:\n                loras.append(pipe.image2lora_style(x=x, residual=None))\n        lora = merge_lora(loras, alpha=1 / len(image2lora_x))\n        return {\"lora\": lora}\n\n\ndef model_fn_z_image_turbo(\n    dit: ZImageDiT,\n    controlnet: ZImageControlNet = None,\n    latents=None,\n    timestep=None,\n    prompt_embeds=None,\n    image_embeds=None,\n    image_latents=None,\n    control_context=None,\n    control_scale=None,\n    use_gradient_checkpointing=False,\n    use_gradient_checkpointing_offload=False,\n    **kwargs,\n):\n    while isinstance(prompt_embeds, list):\n        prompt_embeds = prompt_embeds[0]\n    while isinstance(latents, list):\n        latents = latents[0]\n    while isinstance(image_embeds, list):\n        image_embeds = image_embeds[0]\n\n    # Timestep\n    timestep = 1000 - timestep\n    t_noisy = dit.t_embedder(timestep)\n    t_clean = dit.t_embedder(torch.ones_like(timestep) * 1000)\n\n    # Patchify\n    latents = rearrange(latents, \"B C H 
W -> C B H W\")\n    x, cap_feats, patch_metadata = dit.patchify_and_embed([latents], [prompt_embeds])\n    x = x[0]\n    cap_feats = cap_feats[0]\n\n    # Noise refine\n    x = dit.all_x_embedder[\"2-1\"](x)\n    x[torch.cat(patch_metadata.get(\"x_pad_mask\"))] = dit.x_pad_token.to(dtype=x.dtype, device=x.device)\n    x_freqs_cis = dit.rope_embedder(torch.cat(patch_metadata.get(\"x_pos_ids\"), dim=0))\n    x = rearrange(x, \"L C -> 1 L C\")\n    x_freqs_cis = rearrange(x_freqs_cis, \"L C -> 1 L C\")\n\n    if control_context is not None:\n        kwargs = dict(attn_mask=None, freqs_cis=x_freqs_cis, adaln_input=t_noisy)\n        refiner_hints, control_context, control_context_item_seqlens = controlnet.forward_refiner(\n            dit, x, [cap_feats], control_context, kwargs, t=t_noisy, patch_size=2, f_patch_size=1,\n            use_gradient_checkpointing=use_gradient_checkpointing, use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,\n        )\n    \n    for layer_id, layer in enumerate(dit.noise_refiner):\n        x = gradient_checkpoint_forward(\n            layer,\n            use_gradient_checkpointing=use_gradient_checkpointing,\n            use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,\n            x=x,\n            attn_mask=None,\n            freqs_cis=x_freqs_cis,\n            adaln_input=t_noisy,\n        )\n        if control_context is not None:\n            x = x + refiner_hints[layer_id] * control_scale\n\n    # Prompt refine\n    cap_feats = dit.cap_embedder(cap_feats)\n    cap_feats[torch.cat(patch_metadata.get(\"cap_pad_mask\"))] = dit.cap_pad_token.to(dtype=x.dtype, device=x.device)\n    cap_freqs_cis = dit.rope_embedder(torch.cat(patch_metadata.get(\"cap_pos_ids\"), dim=0))\n    cap_feats = rearrange(cap_feats, \"L C -> 1 L C\")\n    cap_freqs_cis = rearrange(cap_freqs_cis, \"L C -> 1 L C\")\n    \n    for layer in dit.context_refiner:\n        cap_feats = gradient_checkpoint_forward(\n            
layer,\n            use_gradient_checkpointing=use_gradient_checkpointing,\n            use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,\n            x=cap_feats,\n            attn_mask=None,\n            freqs_cis=cap_freqs_cis,\n        )\n\n    # Unified\n    unified = torch.cat([x, cap_feats], dim=1)\n    unified_freqs_cis = torch.cat([x_freqs_cis, cap_freqs_cis], dim=1)\n\n    if control_context is not None:\n        kwargs = dict(attn_mask=None, freqs_cis=unified_freqs_cis, adaln_input=t_noisy)\n        hints = controlnet.forward_layers(\n            unified, cap_feats, control_context, control_context_item_seqlens, kwargs,\n            use_gradient_checkpointing=use_gradient_checkpointing, use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,\n        )\n\n    for layer_id, layer in enumerate(dit.layers):\n        unified = gradient_checkpoint_forward(\n            layer,\n            use_gradient_checkpointing=use_gradient_checkpointing,\n            use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,\n            x=unified,\n            attn_mask=None,\n            freqs_cis=unified_freqs_cis,\n            adaln_input=t_noisy,\n        )\n        if control_context is not None:\n            if layer_id in controlnet.control_layers_mapping:\n                unified = unified + hints[controlnet.control_layers_mapping[layer_id]] * control_scale\n    \n    # Output\n    unified = dit.all_final_layer[\"2-1\"](unified, t_noisy)\n    x = dit.unpatchify([unified[0]], patch_metadata.get(\"x_size\"))[0]\n    x = rearrange(x, \"C B H W -> B C H W\")\n    x = -x\n    return x\n\n\ndef apply_npu_patch(enable_npu_patch: bool=True):\n    if IS_NPU_AVAILABLE and enable_npu_patch:\n        from ..models.general_modules import RMSNorm\n        from transformers.models.qwen3.modeling_qwen3 import Qwen3RMSNorm\n        from ..models.z_image_dit import Attention\n        from ..core.npu_patch.npu_fused_operator import 
(\n            rms_norm_forward_npu, \n            rms_norm_forward_transformers_npu,\n            rotary_emb_Zimage_npu\n        )\n        warnings.warn(\"Replacing RMSNorm and Rope with NPU fusion operators to improve the performance of the model on NPU.Set enable_npu_patch=False to disable this feature.\")\n        RMSNorm.forward = rms_norm_forward_npu\n        Qwen3RMSNorm.forward = rms_norm_forward_transformers_npu\n        Attention.apply_rotary_emb = rotary_emb_Zimage_npu\n"
  },
  {
    "path": "diffsynth/utils/controlnet/__init__.py",
    "content": "\"\"\"Public ControlNet helpers: `ControlNetInput` and `Annotator`.\"\"\"\nfrom .controlnet_input import ControlNetInput as ControlNetInput\nfrom .annotator import Annotator as Annotator\n"
  },
  {
    "path": "diffsynth/utils/controlnet/annotator.py",
    "content": "from typing_extensions import Literal, TypeAlias\n\nfrom diffsynth.core.device.npu_compatible_device import get_device_type\n\nProcessor_id: TypeAlias = Literal[\n    \"canny\", \"depth\", \"softedge\", \"lineart\", \"lineart_anime\", \"openpose\", \"normal\", \"tile\", \"none\", \"inpaint\"\n]\n\n# Processor ids whose detector is loaded with `from_pretrained(model_path)`,\n# mapped to the detector class name in `controlnet_aux.processor`.\n_PRETRAINED_DETECTORS = {\n    \"depth\": \"MidasDetector\",\n    \"softedge\": \"HEDdetector\",\n    \"lineart\": \"LineartDetector\",\n    \"lineart_anime\": \"LineartAnimeDetector\",\n    \"openpose\": \"OpenposeDetector\",\n    \"normal\": \"NormalBaeDetector\",\n}\n\n# Processor ids that run no detector; images are only resized.\n_PASSTHROUGH_PROCESSORS = (\"tile\", \"none\", \"inpaint\")\n\n\nclass Annotator:\n    \"\"\"Optionally runs a `controlnet_aux` detector on an image, then resizes the result to the input size.\n\n    Args:\n        processor_id: Which annotator to use (see `Processor_id`).\n        model_path: Passed to `from_pretrained` for detectors that load weights.\n        detect_resolution: Passed to the detector; defaults to the shorter image side.\n        device: Device the detector is moved to. None (the default) means `get_device_type()`,\n            evaluated when the annotator is constructed rather than when this module is imported.\n        skip_processor: If True, no detector is created and images are only resized.\n\n    Raises:\n        ValueError: If `processor_id` is not supported (only checked when `skip_processor` is False).\n    \"\"\"\n\n    def __init__(self, processor_id: Processor_id, model_path=\"models/Annotators\", detect_resolution=None, device=None, skip_processor=False):\n        if skip_processor:\n            self.processor = None\n        else:\n            if device is None:\n                device = get_device_type()\n            self.processor = self._build_processor(processor_id, model_path, device)\n        self.processor_id = processor_id\n        self.detect_resolution = detect_resolution\n\n    @staticmethod\n    def _build_processor(processor_id, model_path, device):\n        \"\"\"Create the detector for `processor_id`; returns None for pass-through ids.\"\"\"\n        if processor_id == \"canny\":\n            from controlnet_aux.processor import CannyDetector\n            return CannyDetector()\n        if processor_id in _PRETRAINED_DETECTORS:\n            import controlnet_aux.processor as aux_processor\n            detector_cls = getattr(aux_processor, _PRETRAINED_DETECTORS[processor_id])\n            return detector_cls.from_pretrained(model_path).to(device)\n        if processor_id in _PASSTHROUGH_PROCESSORS:\n            return None\n        raise ValueError(f\"Unsupported processor_id: {processor_id}\")\n\n    def to(self, device):\n        \"\"\"Move the detector's `model` to `device` if it has a movable one; returns self for chaining.\"\"\"\n        model = getattr(self.processor, \"model\", None)\n        if hasattr(model, \"to\"):\n            model.to(device)\n        return self\n\n    def __call__(self, image, mask=None):\n        \"\"\"Annotate a PIL image; the result is resized back to the input (width, height).\n\n        `mask` is accepted for interface compatibility and is not used.\n        \"\"\"\n        width, height = image.size\n        if self.processor is not None:\n            # OpenPose: enable body, hand and face detection explicitly.\n            if self.processor_id == \"openpose\":\n                kwargs = {\"include_body\": True, \"include_hand\": True, \"include_face\": True}\n            else:\n                kwargs = {}\n            detect_resolution = self.detect_resolution if self.detect_resolution is not None else min(width, height)\n            image = self.processor(image, detect_resolution=detect_resolution, image_resolution=min(width, height), **kwargs)\n        return image.resize((width, height))\n"
  },
  {
    "path": "diffsynth/utils/controlnet/controlnet_input.py",
    "content": "from dataclasses import dataclass\nfrom typing import Optional\nfrom PIL import Image\n\n\n@dataclass\nclass ControlNetInput:\n    \"\"\"Settings and images for one ControlNet condition passed to a pipeline.\n\n    For example, the Z-Image ControlNet unit reads `image`, `inpaint_image`, `inpaint_mask` and `scale`.\n    \"\"\"\n    controlnet_id: int = 0\n    # Strength multiplier; the Z-Image ControlNet scales its hints by this value.\n    scale: float = 1.0\n    # NOTE(review): `start` defaults to 1.0 and `end` to 0.0; presumably they bound the denoising\n    # progress (running from 1 to 0) during which the ControlNet is applied. Confirm in the pipelines.\n    start: float = 1.0\n    end: float = 0.0\n    # Condition image; when None, the Z-Image unit substitutes a constant latent.\n    image: Optional[Image.Image] = None\n    # Inpainting inputs; the Z-Image unit ignores `inpaint_image` when `inpaint_mask` is None.\n    inpaint_image: Optional[Image.Image] = None\n    inpaint_mask: Optional[Image.Image] = None\n    # NOTE(review): presumably an `Annotator` processor id; confirm where it is consumed.\n    processor_id: Optional[str] = None\n"
  },
  {
    "path": "diffsynth/utils/data/__init__.py",
    "content": "import imageio, os\nimport numpy as np\nfrom PIL import Image\nfrom tqdm import tqdm\nimport subprocess\nimport shutil\n\n\nclass LowMemoryVideo:\n    \"\"\"Random-access view over a video file that decodes one frame at a time as an RGB PIL image.\"\"\"\n\n    def __init__(self, file_name):\n        self.reader = imageio.get_reader(file_name)\n\n    def __len__(self):\n        return self.reader.count_frames()\n\n    def __getitem__(self, item):\n        return Image.fromarray(np.array(self.reader.get_data(item))).convert(\"RGB\")\n\n    def __del__(self):\n        # `reader` is missing when `imageio.get_reader` raised in __init__; don't raise again here.\n        reader = getattr(self, \"reader\", None)\n        if reader is not None:\n            reader.close()\n\n\ndef split_file_name(file_name):\n    \"\"\"Split a file name into a tuple where each run of ASCII digits becomes one int and every\n    other character stays a one-character string, e.g. \"f12.png\" -> (\"f\", 12, \".\", \"p\", \"n\", \"g\").\n    \"\"\"\n    result = []\n    number = -1\n    for i in file_name:\n        if ord(i)>=ord(\"0\") and ord(i)<=ord(\"9\"):\n            if number == -1:\n                number = 0\n            number = number*10 + ord(i) - ord(\"0\")\n        else:\n            if number != -1:\n                result.append(number)\n                number = -1\n            result.append(i)\n    if number != -1:\n        result.append(number)\n    result = tuple(result)\n    return result\n\n\ndef _natural_sort_key(file_name):\n    \"\"\"Natural-order sort key built from `split_file_name`.\n\n    Ints and strings are tagged so that a digit run and a character at the same position compare\n    (digits first) instead of raising TypeError; ties fall back to the raw file name.\n    \"\"\"\n    parts = tuple((0, part) if isinstance(part, int) else (1, part) for part in split_file_name(file_name))\n    return parts, file_name\n\n\ndef search_for_images(folder):\n    \"\"\"Return the paths of the .jpg/.png files in `folder`, sorted in natural file-name order.\"\"\"\n    file_list = [i for i in os.listdir(folder) if i.endswith(\".jpg\") or i.endswith(\".png\")]\n    return [os.path.join(folder, i) for i in sorted(file_list, key=_natural_sort_key)]\n\n\nclass LowMemoryImageFolder:\n    def __init__(self, folder, file_list=None):\n        if file_list is None:\n            self.file_list = search_for_images(folder)\n        else:\n            self.file_list = [os.path.join(folder, file_name) for file_name in file_list]\n\n    def __len__(self):\n        return len(self.file_list)\n\n    def __getitem__(self, item):\n        return Image.open(self.file_list[item]).convert(\"RGB\")\n\n    def __del__(self):\n        pass\n\n\ndef crop_and_resize(image, height, width):\n    \"\"\"Center-crop `image` to the height:width aspect ratio, then resize it to (width, height).\n\n    Returns a PIL image. The input must convert to an (H, W, C) array.\n    \"\"\"\n    image = np.array(image)\n    image_height, image_width, _ = image.shape\n    if image_height / image_width < height / width:\n        # Relatively wider than the target: trim columns on both sides.\n        cropped_width = int(image_height / height * width)\n        left = (image_width - cropped_width) // 2\n        image = image[:, left: left+cropped_width]\n    else:\n        # Relatively taller than the target (or the same ratio): trim rows on both sides.\n        cropped_height = int(image_width / width * height)\n        top = (image_height - cropped_height) // 2\n        image = image[top: top+cropped_height, :]\n    return Image.fromarray(image).resize((width, height))\n\n\nclass VideoData:\n    def __init__(self, video_file=None, image_folder=None, height=None, width=None, **kwargs):\n        if video_file is not None:\n            self.data_type = \"video\"\n            self.data = LowMemoryVideo(video_file, **kwargs)\n        elif image_folder is not None:\n            self.data_type = \"images\"\n            self.data = LowMemoryImageFolder(image_folder, **kwargs)\n        else:\n            raise ValueError(\"Cannot open video or image folder\")\n        self.length = None\n        self.set_shape(height, width)\n\n    def raw_data(self):\n        frames = []\n        for i in range(self.__len__()):\n            frames.append(self.__getitem__(i))\n        return frames\n\n    def set_length(self, length):\n        self.length = length\n\n    def set_shape(self, height, width):\n        self.height = height\n        self.width = width\n\n    def __len__(self):\n        if self.length is None:\n            return len(self.data)\n        else:\n            return self.length\n\n    def shape(self):\n        if self.height is not None and self.width is not None:\n            return self.height, self.width\n        else:\n            width, height = self.__getitem__(0).size\n            return height, width\n\n    def __getitem__(self, item):\n        frame = self.data.__getitem__(item)\n        width, height = frame.size\n        if self.height is not None and self.width is not None:\n            if self.height != height or self.width != width:\n                frame = crop_and_resize(frame, self.height, self.width)\n        return frame\n\n    def __del__(self):\n        pass\n\n    def save_images(self, folder):\n        os.makedirs(folder, exist_ok=True)\n        for i in tqdm(range(self.__len__()), desc=\"Saving images\"):\n            frame = self.__getitem__(i)\n            frame.save(os.path.join(folder, f\"{i}.png\"))\n\n\ndef save_video(frames, save_path, fps, quality=9, ffmpeg_params=None):\n    \"\"\"Write `frames` (PIL images or arrays) to a video file at `save_path` using imageio.\"\"\"\n    writer = imageio.get_writer(save_path, fps=fps, quality=quality, ffmpeg_params=ffmpeg_params)\n    try:\n        for frame in tqdm(frames, desc=\"Saving video\"):\n            writer.append_data(np.array(frame))\n    finally:\n        # Close the writer even when a frame fails, so the output file is not left open.\n        writer.close()\n\ndef save_frames(frames, save_path):\n    os.makedirs(save_path, exist_ok=True)\n    for i, frame in enumerate(tqdm(frames, desc=\"Saving images\")):\n        frame.save(os.path.join(save_path, f\"{i}.png\"))\n\n\ndef merge_video_audio(video_path: str, audio_path: str):\n    # TODO: may need an in-Python implementation to avoid the subprocess dependency\n    \"\"\"\n    Merge the video and audio into a new video, with the duration set to the shorter of the two,\n    and overwrite the original video file.\n\n    Parameters:\n    video_path (str): Path to the original video file\n    audio_path (str): Path to the audio file\n    \"\"\"\n\n    # check\n    if not os.path.exists(video_path):\n        raise FileNotFoundError(f\"video file {video_path} does not exist\")\n    if not os.path.exists(audio_path):\n        raise FileNotFoundError(f\"audio file {audio_path} does not exist\")\n\n    base, ext = os.path.splitext(video_path)\n    temp_output = f\"{base}_temp{ext}\"\n\n    try:\n        # create ffmpeg command\n        command = [\n            'ffmpeg',\n            '-y',  # overwrite\n            '-i',\n            video_path,\n            '-i',\n            audio_path,\n            '-c:v',\n            'copy',  # copy video stream\n            '-c:a',\n            'aac',  # use AAC audio encoder\n            '-b:a',\n            '192k',  # set audio bitrate (optional)\n            '-map',\n            '0:v:0',  # select the first video stream\n            '-map',\n            '1:a:0',  # select the first audio stream\n            '-shortest',  # choose the shortest duration\n            temp_output\n        ]\n\n        # execute the command\n        result = subprocess.run(\n            command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)\n\n        # check result\n        if result.returncode != 0:\n            error_msg = f\"FFmpeg execute failed: {result.stderr}\"\n            print(error_msg)\n            raise RuntimeError(error_msg)\n\n        shutil.move(temp_output, video_path)\n        print(f\"Merge completed, saved to {video_path}\")\n\n    except Exception as e:\n        if os.path.exists(temp_output):\n            os.remove(temp_output)\n        print(f\"merge_video_audio failed with error: {e}\")\n\n\ndef save_video_with_audio(frames, save_path, audio_path, fps=16, quality=9, ffmpeg_params=None):\n    save_video(frames, save_path, fps, quality, ffmpeg_params)\n    merge_video_audio(save_path, audio_path)\n"
  },
  {
    "path": "diffsynth/utils/data/audio.py",
    "content": "import torch\nimport torchaudio\n\n\ndef convert_to_mono(audio_tensor: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    Convert audio to mono by averaging channels.\n    Supports [C, T], [B, C, T] or any [..., C, T] layout. Output shape: [1, T], [B, 1, T], ... (channel axis kept with size 1).\n    \"\"\"\n    return audio_tensor.mean(dim=-2, keepdim=True)\n\n\ndef convert_to_stereo(audio_tensor: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    Convert audio to stereo.\n    Supports [C, T], [B, C, T] or any [..., C, T] layout. A single (mono) channel is duplicated;\n    tensors with more than one channel are returned unchanged.\n    \"\"\"\n    if audio_tensor.size(-2) == 1:\n        # Repeat only along the channel axis; all leading axes keep their size.\n        return audio_tensor.repeat(*([1] * (audio_tensor.dim() - 2)), 2, 1)\n    return audio_tensor\n\n\ndef resample_waveform(waveform: torch.Tensor, source_rate: int, target_rate: int) -> torch.Tensor:\n    \"\"\"Resample waveform to target sample rate if needed, keeping the input dtype.\"\"\"\n    if source_rate == target_rate:\n        return waveform\n    resampled = torchaudio.functional.resample(waveform, source_rate, target_rate)\n    return resampled.to(dtype=waveform.dtype)\n\n\ndef read_audio_with_torchcodec(\n    path: str,\n    start_time: float = 0,\n    duration: float | None = None,\n) -> tuple[torch.Tensor, int]:\n    \"\"\"\n    Read audio from file natively using torchcodec, with optional start time and duration.\n\n    Args:\n        path (str): The file path to the audio file.\n        start_time (float, optional): The start time in seconds to read from. Defaults to 0.\n        duration (float | None, optional): The duration in seconds to read. If None, reads until the end. Defaults to None.\n\n    Returns:\n        tuple[torch.Tensor, int]: A tuple containing the audio tensor and the sample rate.\n            The audio tensor shape is [C, T] where C is the number of channels and T is the number of audio frames.\n    \"\"\"\n    from torchcodec.decoders import AudioDecoder\n    decoder = AudioDecoder(path)\n    stop_seconds = None if duration is None else start_time + duration\n    waveform = decoder.get_samples_played_in_range(start_seconds=start_time, stop_seconds=stop_seconds).data\n    return waveform, decoder.metadata.sample_rate\n\n\ndef read_audio(\n    path: str,\n    start_time: float = 0,\n    duration: float | None = None,\n    resample: bool = False,\n    resample_rate: int = 48000,\n    backend: str = \"torchcodec\",\n) -> tuple[torch.Tensor, int]:\n    \"\"\"\n    Read audio from file, with optional start time, duration, and resampling.\n\n    Args:\n        path (str): The file path to the audio file.\n        start_time (float, optional): The start time in seconds to read from. Defaults to 0.\n        duration (float | None, optional): The duration in seconds to read. If None, reads until the end. Defaults to None.\n        resample (bool, optional): Whether to resample the audio to a different sample rate. Defaults to False.\n        resample_rate (int, optional): The target sample rate for resampling if resample is True. Defaults to 48000.\n        backend (str, optional): The audio backend to use for reading. Defaults to \"torchcodec\".\n\n    Returns:\n        tuple[torch.Tensor, int]: A tuple containing the audio tensor and the sample rate.\n            The audio tensor shape is [C, T] where C is the number of channels and T is the number of audio frames.\n\n    Raises:\n        ValueError: If `backend` is not supported.\n    \"\"\"\n    if backend == \"torchcodec\":\n        waveform, sample_rate = read_audio_with_torchcodec(path, start_time, duration)\n    else:\n        raise ValueError(f\"Unsupported audio backend: {backend}\")\n\n    if resample:\n        waveform = resample_waveform(waveform, sample_rate, resample_rate)\n        sample_rate = resample_rate\n\n    return waveform, sample_rate\n\n\ndef save_audio(waveform: torch.Tensor, sample_rate: int, save_path: str, backend: str = \"torchcodec\"):\n    \"\"\"\n    Save audio tensor to file.\n\n    Args:\n        waveform (torch.Tensor): The audio tensor to save. Shape can be [C, T] or [B, C, T];\n            for [B, C, T] only the first batch item (waveform[0]) is written.\n        sample_rate (int): The sample rate of the audio.\n        save_path (str): The file path to save the audio to.\n        backend (str, optional): The audio backend to use for saving. Defaults to \"torchcodec\".\n\n    Raises:\n        ValueError: If `backend` is not supported.\n    \"\"\"\n    if waveform.dim() == 3:\n        waveform = waveform[0]\n\n    if backend == \"torchcodec\":\n        from torchcodec.encoders import AudioEncoder\n        encoder = AudioEncoder(waveform, sample_rate=sample_rate)\n        encoder.to_file(dest=save_path)\n    else:\n        raise ValueError(f\"Unsupported audio backend: {backend}\")\n"
  },
  {
    "path": "diffsynth/utils/data/audio_video.py",
    "content": "import av\nfrom fractions import Fraction\nimport torch\nfrom PIL import Image\nfrom tqdm import tqdm\nfrom .audio import convert_to_stereo\n\n\ndef _resample_audio(\n    container: av.container.Container, audio_stream: av.audio.AudioStream, frame_in: av.AudioFrame\n) -> None:\n    \"\"\"Resample `frame_in` to the audio encoder's format/layout/rate, encode it and mux the packets.\"\"\"\n    cc = audio_stream.codec_context\n\n    # Use the encoder's format/layout/rate as the *target*\n    target_format = cc.format or \"fltp\"  # AAC → usually fltp\n    target_layout = cc.layout or \"stereo\"\n    target_rate = cc.sample_rate or frame_in.sample_rate\n\n    audio_resampler = av.audio.resampler.AudioResampler(\n        format=target_format,\n        layout=target_layout,\n        rate=target_rate,\n    )\n\n    audio_next_pts = 0\n    # Feed the input frame, then None to flush samples still buffered in the resampler,\n    # so the tail of the audio is not dropped.\n    for source in (frame_in, None):\n        for rframe in audio_resampler.resample(source):\n            if rframe.pts is None:\n                rframe.pts = audio_next_pts\n            audio_next_pts += rframe.samples\n            # Resampled frames already carry `target_rate`; relabelling them with the input rate\n            # would not match the encoder whenever the two rates differ.\n            container.mux(audio_stream.encode(rframe))\n\n    # flush audio encoder\n    for packet in audio_stream.encode():\n        container.mux(packet)\n\n\ndef _write_audio(\n    container: av.container.Container, audio_stream: av.audio.AudioStream, samples: torch.Tensor, audio_sample_rate: int\n) -> None:\n    if samples.ndim == 1:\n        samples = samples.unsqueeze(0)\n    samples = convert_to_stereo(samples)\n    assert samples.ndim == 2 and samples.shape[0] == 2, \"audio samples must be [C, S] or [S], C must be 1 or 2\"\n    samples = samples.T\n    # Convert to int16 packed for ingestion; resampler converts to encoder fmt.\n    if samples.dtype != torch.int16:\n        samples = torch.clip(samples, -1.0, 1.0)\n        samples = (samples * 32767.0).to(torch.int16)\n\n    frame_in = av.AudioFrame.from_ndarray(\n        samples.contiguous().reshape(1, -1).cpu().numpy(),\n        format=\"s16\",\n        layout=\"stereo\",\n    )\n    frame_in.sample_rate = audio_sample_rate\n\n    _resample_audio(container, audio_stream, frame_in)\n\n\ndef _prepare_audio_stream(container: av.container.Container, audio_sample_rate: int) -> av.audio.AudioStream:\n    \"\"\"\n    Prepare the audio stream for writing.\n    \"\"\"\n    audio_stream = container.add_stream(\"aac\")\n    supported_sample_rates = audio_stream.codec_context.codec.audio_rates\n    if supported_sample_rates:\n        best_rate = min(supported_sample_rates, key=lambda x: abs(x - audio_sample_rate))\n        if best_rate != audio_sample_rate:\n            print(f\"Using closest supported audio sample rate: {best_rate}\")\n    else:\n        best_rate = audio_sample_rate\n    audio_stream.codec_context.sample_rate = best_rate\n    audio_stream.codec_context.layout = \"stereo\"\n    audio_stream.codec_context.time_base = Fraction(1, best_rate)\n    return audio_stream\n\n\ndef write_video_audio(\n    video: list[Image.Image],\n    audio: torch.Tensor | None,\n    output_path: str,\n    fps: int = 24,\n    audio_sample_rate: int | None = None,\n) -> None:\n    \"\"\"\n    Writes a sequence of images and an optional audio tensor to a video file.\n\n    The frames are encoded with PyAV as a libx264 (yuv420p) stream; the audio, if given, is encoded\n    as a stereo AAC stream and multiplexed into the same container.\n\n    Args:\n        video (list[Image.Image]): A non-empty list of PIL Image objects representing the video frames.\n            The length of this list determines the total duration of the video based on the FPS.\n        audio (torch.Tensor | None): The audio data as a PyTorch tensor of shape [C, S] or [S],\n            with C = 1 (mono, duplicated to stereo) or C = 2 (stereo). Pass None for a silent video.\n        output_path (str): The file path (including extension) where the output video will be saved.\n        fps (int, optional): The frame rate (frames per second) for the video. Defaults to 24.\n        audio_sample_rate (int | None, optional): The sample rate (e.g., 44100, 48000) of the audio.\n            If audio is provided and this is None, the rate is inferred as the number of audio\n            samples divided by the video duration. Ignored when audio is None.\n    Raises:\n        ValueError: If `video` is empty, or if audio is provided but its sample rate cannot be determined.\n    \"\"\"\n    if len(video) == 0:\n        raise ValueError(\"video must contain at least one frame\")\n    duration = len(video) / fps\n    if audio is not None and audio_sample_rate is None:\n        # Infer the rate assuming the audio spans exactly the video duration.\n        audio_sample_rate = int(audio.shape[-1] / duration)\n        if audio_sample_rate <= 0:\n            raise ValueError(\"audio_sample_rate is required when audio is provided\")\n\n    width, height = video[0].size\n    container = av.open(output_path, mode=\"w\")\n    try:\n        stream = container.add_stream(\"libx264\", rate=int(fps))\n        stream.width = width\n        stream.height = height\n        stream.pix_fmt = \"yuv420p\"\n\n        if audio is not None:\n            audio_stream = _prepare_audio_stream(container, audio_sample_rate)\n\n        for frame in tqdm(video, total=len(video)):\n            frame = av.VideoFrame.from_image(frame)\n            for packet in stream.encode(frame):\n                container.mux(packet)\n\n        # Flush encoder\n        for packet in stream.encode():\n            container.mux(packet)\n\n        if audio is not None:\n            _write_audio(container, audio_stream, audio, audio_sample_rate)\n    finally:\n        # Always release the output file, even if encoding fails part-way.\n        container.close()\n"
  },
  {
    "path": "diffsynth/utils/data/media_io_ltx2.py",
    "content": "import av\nimport numpy as np\nfrom io import BytesIO\nfrom typing import BinaryIO, Union\nfrom .audio_video import write_video_audio as write_video_audio_ltx2\n\n\ndef encode_single_frame(output_file: Union[str, BinaryIO], image_array: np.ndarray, crf: float) -> None:\n    \"\"\"Encode one RGB frame as a single-frame libx264 MP4 into `output_file`.\n\n    Args:\n        output_file: Destination path or writable binary file object (e.g. `BytesIO`).\n        image_array: Frame passed to PyAV as \"rgb24\"; presumably a uint8 (H, W, 3) array.\n        crf: Value of the libx264 `crf` option.\n    \"\"\"\n    container = av.open(output_file, \"w\", format=\"mp4\")\n    try:\n        stream = container.add_stream(\"libx264\", rate=1, options={\"crf\": str(crf), \"preset\": \"veryfast\"})\n        # Round height and width DOWN to even numbers (dropping the last row/column when odd)\n        # for compatibility with video codecs.\n        height = image_array.shape[0] // 2 * 2\n        width = image_array.shape[1] // 2 * 2\n        image_array = image_array[:height, :width]\n        stream.height = height\n        stream.width = width\n        av_frame = av.VideoFrame.from_ndarray(image_array, format=\"rgb24\").reformat(format=\"yuv420p\")\n        container.mux(stream.encode(av_frame))\n        container.mux(stream.encode())\n    finally:\n        container.close()\n\n\ndef decode_single_frame(video_file: Union[str, BinaryIO]) -> np.ndarray:\n    \"\"\"Decode the first video frame of `video_file` (a path or binary file object) as an RGB array.\"\"\"\n    container = av.open(video_file)\n    try:\n        stream = next(s for s in container.streams if s.type == \"video\")\n        frame = next(container.decode(stream))\n        # Convert while the container is still open.\n        return frame.to_ndarray(format=\"rgb24\")\n    finally:\n        container.close()\n\n\ndef ltx2_preprocess(image: np.ndarray, crf: float = 33) -> np.ndarray:\n    \"\"\"Round-trip `image` through an in-memory single-frame libx264 encode/decode at the given CRF.\n\n    Returns `image` unchanged when `crf` is 0; otherwise returns the decoded RGB array, whose height\n    and width are rounded down to even numbers by `encode_single_frame`.\n    \"\"\"\n    if crf == 0:\n        return image\n\n    with BytesIO() as output_file:\n        encode_single_frame(output_file, image, crf)\n        video_bytes = output_file.getvalue()\n    with BytesIO(video_bytes) as video_file:\n        return decode_single_frame(video_file)\n"
  },
  {
    "path": "diffsynth/utils/lora/__init__.py",
    "content": "from .general import GeneralLoRALoader\nfrom .merge import merge_lora\nfrom .reset_rank import reset_lora_rank"
  },
  {
    "path": "diffsynth/utils/lora/flux.py",
    "content": "from .general import GeneralLoRALoader\nimport torch, math\n\n\nclass FluxLoRALoader(GeneralLoRALoader):\n    def __init__(self, device=\"cpu\", torch_dtype=torch.float32):\n        super().__init__(device=device, torch_dtype=torch_dtype)\n    \n        self.diffusers_rename_dict = {\n            \"transformer.single_transformer_blocks.blockid.attn.to_k.lora_A.weight\":\"single_blocks.blockid.a_to_k.lora_A.weight\",\n            \"transformer.single_transformer_blocks.blockid.attn.to_k.lora_B.weight\":\"single_blocks.blockid.a_to_k.lora_B.weight\",\n            \"transformer.single_transformer_blocks.blockid.attn.to_q.lora_A.weight\":\"single_blocks.blockid.a_to_q.lora_A.weight\",\n            \"transformer.single_transformer_blocks.blockid.attn.to_q.lora_B.weight\":\"single_blocks.blockid.a_to_q.lora_B.weight\",\n            \"transformer.single_transformer_blocks.blockid.attn.to_v.lora_A.weight\":\"single_blocks.blockid.a_to_v.lora_A.weight\",\n            \"transformer.single_transformer_blocks.blockid.attn.to_v.lora_B.weight\":\"single_blocks.blockid.a_to_v.lora_B.weight\",\n            \"transformer.single_transformer_blocks.blockid.norm.linear.lora_A.weight\":\"single_blocks.blockid.norm.linear.lora_A.weight\",\n            \"transformer.single_transformer_blocks.blockid.norm.linear.lora_B.weight\":\"single_blocks.blockid.norm.linear.lora_B.weight\",\n            \"transformer.single_transformer_blocks.blockid.proj_mlp.lora_A.weight\":\"single_blocks.blockid.proj_in_besides_attn.lora_A.weight\",\n            \"transformer.single_transformer_blocks.blockid.proj_mlp.lora_B.weight\":\"single_blocks.blockid.proj_in_besides_attn.lora_B.weight\",\n            \"transformer.single_transformer_blocks.blockid.proj_out.lora_A.weight\":\"single_blocks.blockid.proj_out.lora_A.weight\",\n            \"transformer.single_transformer_blocks.blockid.proj_out.lora_B.weight\":\"single_blocks.blockid.proj_out.lora_B.weight\",\n            
\"transformer.transformer_blocks.blockid.attn.add_k_proj.lora_A.weight\":\"blocks.blockid.attn.b_to_k.lora_A.weight\",\n            \"transformer.transformer_blocks.blockid.attn.add_k_proj.lora_B.weight\":\"blocks.blockid.attn.b_to_k.lora_B.weight\",\n            \"transformer.transformer_blocks.blockid.attn.add_q_proj.lora_A.weight\":\"blocks.blockid.attn.b_to_q.lora_A.weight\",\n            \"transformer.transformer_blocks.blockid.attn.add_q_proj.lora_B.weight\":\"blocks.blockid.attn.b_to_q.lora_B.weight\",\n            \"transformer.transformer_blocks.blockid.attn.add_v_proj.lora_A.weight\":\"blocks.blockid.attn.b_to_v.lora_A.weight\",\n            \"transformer.transformer_blocks.blockid.attn.add_v_proj.lora_B.weight\":\"blocks.blockid.attn.b_to_v.lora_B.weight\",\n            \"transformer.transformer_blocks.blockid.attn.to_add_out.lora_A.weight\":\"blocks.blockid.attn.b_to_out.lora_A.weight\",\n            \"transformer.transformer_blocks.blockid.attn.to_add_out.lora_B.weight\":\"blocks.blockid.attn.b_to_out.lora_B.weight\",\n            \"transformer.transformer_blocks.blockid.attn.to_k.lora_A.weight\":\"blocks.blockid.attn.a_to_k.lora_A.weight\",\n            \"transformer.transformer_blocks.blockid.attn.to_k.lora_B.weight\":\"blocks.blockid.attn.a_to_k.lora_B.weight\",\n            \"transformer.transformer_blocks.blockid.attn.to_out.0.lora_A.weight\":\"blocks.blockid.attn.a_to_out.lora_A.weight\",\n            \"transformer.transformer_blocks.blockid.attn.to_out.0.lora_B.weight\":\"blocks.blockid.attn.a_to_out.lora_B.weight\",\n            \"transformer.transformer_blocks.blockid.attn.to_q.lora_A.weight\":\"blocks.blockid.attn.a_to_q.lora_A.weight\",\n            \"transformer.transformer_blocks.blockid.attn.to_q.lora_B.weight\":\"blocks.blockid.attn.a_to_q.lora_B.weight\",\n            \"transformer.transformer_blocks.blockid.attn.to_v.lora_A.weight\":\"blocks.blockid.attn.a_to_v.lora_A.weight\",\n            
\"transformer.transformer_blocks.blockid.attn.to_v.lora_B.weight\":\"blocks.blockid.attn.a_to_v.lora_B.weight\",\n            \"transformer.transformer_blocks.blockid.ff.net.0.proj.lora_A.weight\":\"blocks.blockid.ff_a.0.lora_A.weight\",\n            \"transformer.transformer_blocks.blockid.ff.net.0.proj.lora_B.weight\":\"blocks.blockid.ff_a.0.lora_B.weight\",\n            \"transformer.transformer_blocks.blockid.ff.net.2.lora_A.weight\":\"blocks.blockid.ff_a.2.lora_A.weight\",\n            \"transformer.transformer_blocks.blockid.ff.net.2.lora_B.weight\":\"blocks.blockid.ff_a.2.lora_B.weight\",\n            \"transformer.transformer_blocks.blockid.ff_context.net.0.proj.lora_A.weight\":\"blocks.blockid.ff_b.0.lora_A.weight\",\n            \"transformer.transformer_blocks.blockid.ff_context.net.0.proj.lora_B.weight\":\"blocks.blockid.ff_b.0.lora_B.weight\",\n            \"transformer.transformer_blocks.blockid.ff_context.net.2.lora_A.weight\":\"blocks.blockid.ff_b.2.lora_A.weight\",\n            \"transformer.transformer_blocks.blockid.ff_context.net.2.lora_B.weight\":\"blocks.blockid.ff_b.2.lora_B.weight\",\n            \"transformer.transformer_blocks.blockid.norm1.linear.lora_A.weight\":\"blocks.blockid.norm1_a.linear.lora_A.weight\",\n            \"transformer.transformer_blocks.blockid.norm1.linear.lora_B.weight\":\"blocks.blockid.norm1_a.linear.lora_B.weight\",\n            \"transformer.transformer_blocks.blockid.norm1_context.linear.lora_A.weight\":\"blocks.blockid.norm1_b.linear.lora_A.weight\",\n            \"transformer.transformer_blocks.blockid.norm1_context.linear.lora_B.weight\":\"blocks.blockid.norm1_b.linear.lora_B.weight\",\n        }\n\n        self.civitai_rename_dict = {\n            \"lora_unet_double_blocks_blockid_img_mod_lin.lora_down.weight\": \"blocks.blockid.norm1_a.linear.lora_A.weight\",\n            \"lora_unet_double_blocks_blockid_img_mod_lin.lora_up.weight\": \"blocks.blockid.norm1_a.linear.lora_B.weight\",\n            
\"lora_unet_double_blocks_blockid_txt_mod_lin.lora_down.weight\": \"blocks.blockid.norm1_b.linear.lora_A.weight\",\n            \"lora_unet_double_blocks_blockid_txt_mod_lin.lora_up.weight\": \"blocks.blockid.norm1_b.linear.lora_B.weight\",\n            \"lora_unet_double_blocks_blockid_img_attn_qkv.lora_down.weight\": \"blocks.blockid.attn.a_to_qkv.lora_A.weight\",\n            \"lora_unet_double_blocks_blockid_img_attn_qkv.lora_up.weight\": \"blocks.blockid.attn.a_to_qkv.lora_B.weight\",\n            \"lora_unet_double_blocks_blockid_txt_attn_qkv.lora_down.weight\": \"blocks.blockid.attn.b_to_qkv.lora_A.weight\",\n            \"lora_unet_double_blocks_blockid_txt_attn_qkv.lora_up.weight\": \"blocks.blockid.attn.b_to_qkv.lora_B.weight\",\n            \"lora_unet_double_blocks_blockid_img_attn_proj.lora_down.weight\": \"blocks.blockid.attn.a_to_out.lora_A.weight\",\n            \"lora_unet_double_blocks_blockid_img_attn_proj.lora_up.weight\": \"blocks.blockid.attn.a_to_out.lora_B.weight\",\n            \"lora_unet_double_blocks_blockid_txt_attn_proj.lora_down.weight\": \"blocks.blockid.attn.b_to_out.lora_A.weight\",\n            \"lora_unet_double_blocks_blockid_txt_attn_proj.lora_up.weight\": \"blocks.blockid.attn.b_to_out.lora_B.weight\",\n            \"lora_unet_double_blocks_blockid_img_mlp_0.lora_down.weight\": \"blocks.blockid.ff_a.0.lora_A.weight\",\n            \"lora_unet_double_blocks_blockid_img_mlp_0.lora_up.weight\": \"blocks.blockid.ff_a.0.lora_B.weight\",\n            \"lora_unet_double_blocks_blockid_img_mlp_2.lora_down.weight\": \"blocks.blockid.ff_a.2.lora_A.weight\",\n            \"lora_unet_double_blocks_blockid_img_mlp_2.lora_up.weight\": \"blocks.blockid.ff_a.2.lora_B.weight\",\n            \"lora_unet_double_blocks_blockid_txt_mlp_0.lora_down.weight\": \"blocks.blockid.ff_b.0.lora_A.weight\",\n            \"lora_unet_double_blocks_blockid_txt_mlp_0.lora_up.weight\": \"blocks.blockid.ff_b.0.lora_B.weight\",\n            
\"lora_unet_double_blocks_blockid_txt_mlp_2.lora_down.weight\": \"blocks.blockid.ff_b.2.lora_A.weight\",\n            \"lora_unet_double_blocks_blockid_txt_mlp_2.lora_up.weight\": \"blocks.blockid.ff_b.2.lora_B.weight\",\n            \"lora_unet_single_blocks_blockid_modulation_lin.lora_down.weight\": \"single_blocks.blockid.norm.linear.lora_A.weight\",\n            \"lora_unet_single_blocks_blockid_modulation_lin.lora_up.weight\": \"single_blocks.blockid.norm.linear.lora_B.weight\",\n            \"lora_unet_single_blocks_blockid_linear1.lora_down.weight\": \"single_blocks.blockid.to_qkv_mlp.lora_A.weight\",\n            \"lora_unet_single_blocks_blockid_linear1.lora_up.weight\": \"single_blocks.blockid.to_qkv_mlp.lora_B.weight\",\n            \"lora_unet_single_blocks_blockid_linear2.lora_down.weight\": \"single_blocks.blockid.proj_out.lora_A.weight\",\n            \"lora_unet_single_blocks_blockid_linear2.lora_up.weight\": \"single_blocks.blockid.proj_out.lora_B.weight\",\n        }\n\n    def fuse_lora_to_base_model(self, model: torch.nn.Module, state_dict_lora, alpha=1.0):\n        super().fuse_lora_to_base_model(model, state_dict_lora, alpha)\n    \n    def convert_state_dict(self, state_dict):\n\n        def guess_block_id(name,model_resource):\n            if model_resource == 'civitai':\n                names = name.split(\"_\")\n                for i in names:\n                    if i.isdigit():\n                        return i, name.replace(f\"_{i}_\", \"_blockid_\")\n            if model_resource == 'diffusers':\n                names = name.split(\".\")\n                for i in names:\n                    if i.isdigit():\n                        return i, name.replace(f\"transformer_blocks.{i}.\", \"transformer_blocks.blockid.\")\n            return None, None\n\n        def guess_resource(state_dict):\n            for k in state_dict:\n                if \"lora_unet_\" in k:\n                    return 'civitai'\n                elif 
k.startswith(\"transformer.\"):\n                    return 'diffusers'\n                else:\n                    None\n        \n        model_resource = guess_resource(state_dict)\n        if model_resource is None:\n            return state_dict\n\n        rename_dict = self.diffusers_rename_dict if model_resource == 'diffusers' else self.civitai_rename_dict\n        def guess_alpha(state_dict):\n                for name, param in state_dict.items():\n                    if \".alpha\" in name:\n                        for suffix in [\".lora_down.weight\", \".lora_A.weight\"]:\n                            name_ = name.replace(\".alpha\", suffix)\n                            if name_ in state_dict:\n                                lora_alpha = param.item() / state_dict[name_].shape[0]\n                                lora_alpha = math.sqrt(lora_alpha)\n                                return lora_alpha\n\n                return 1\n        \n        alpha = guess_alpha(state_dict)\n        \n        state_dict_ = {}\n        for name, param in state_dict.items():\n            block_id, source_name = guess_block_id(name,model_resource)\n            if alpha != 1:\n                param *= alpha\n            if source_name in rename_dict:\n                target_name = rename_dict[source_name]\n                target_name = target_name.replace(\".blockid.\", f\".{block_id}.\")\n                state_dict_[target_name] = param\n            else:\n                state_dict_[name] = param\n        \n        if model_resource == 'diffusers':\n            for name in list(state_dict_.keys()):\n                if \"single_blocks.\" in name and \".a_to_q.\" in name:\n                    mlp = state_dict_.get(name.replace(\".a_to_q.\", \".proj_in_besides_attn.\"), None)\n                    if mlp is None:\n                        dim = 4\n                        if 'lora_A' in name:\n                            dim = 1\n                        mlp = torch.zeros(dim * 
state_dict_[name].shape[0],\n                                        *state_dict_[name].shape[1:],\n                                        dtype=state_dict_[name].dtype)\n                    else:\n                        state_dict_.pop(name.replace(\".a_to_q.\", \".proj_in_besides_attn.\"))\n\n                    mlp = mlp.to(device=state_dict_[name].device)\n                    if 'lora_A' in name:\n                        param = torch.concat([\n                            state_dict_.pop(name),\n                            state_dict_.pop(name.replace(\".a_to_q.\", \".a_to_k.\")),\n                            state_dict_.pop(name.replace(\".a_to_q.\", \".a_to_v.\")),\n                            mlp,\n                        ], dim=0)\n                    elif 'lora_B' in name:\n                        d, r = state_dict_[name].shape\n                        param = torch.zeros((3*d+mlp.shape[0], 3*r+mlp.shape[1]), dtype=state_dict_[name].dtype, device=state_dict_[name].device)\n                        param[:d, :r] = state_dict_.pop(name)\n                        param[d:2*d, r:2*r] = state_dict_.pop(name.replace(\".a_to_q.\", \".a_to_k.\"))\n                        param[2*d:3*d, 2*r:3*r] = state_dict_.pop(name.replace(\".a_to_q.\", \".a_to_v.\"))\n                        param[3*d:, 3*r:] = mlp\n                    else:\n                        param = torch.concat([\n                            state_dict_.pop(name),\n                            state_dict_.pop(name.replace(\".a_to_q.\", \".a_to_k.\")),\n                            state_dict_.pop(name.replace(\".a_to_q.\", \".a_to_v.\")),\n                            mlp,\n                        ], dim=0)\n                    name_ = name.replace(\".a_to_q.\", \".to_qkv_mlp.\")\n                    state_dict_[name_] = param\n            for name in list(state_dict_.keys()):\n                for component in [\"a\", \"b\"]:\n                    if f\".{component}_to_q.\" in name:\n                       
 name_ = name.replace(f\".{component}_to_q.\", f\".{component}_to_qkv.\")\n                        concat_dim = 0\n                        if 'lora_A' in name:\n                            param = torch.concat([\n                                state_dict_[name.replace(f\".{component}_to_q.\", f\".{component}_to_q.\")],\n                                state_dict_[name.replace(f\".{component}_to_q.\", f\".{component}_to_k.\")],\n                                state_dict_[name.replace(f\".{component}_to_q.\", f\".{component}_to_v.\")],\n                            ], dim=0)\n                        elif 'lora_B' in name:\n                            origin = state_dict_[name.replace(f\".{component}_to_q.\", f\".{component}_to_q.\")]\n                            d, r = origin.shape\n                            # print(d, r)\n                            param = torch.zeros((3*d, 3*r), dtype=origin.dtype, device=origin.device)\n                            param[:d, :r] = state_dict_[name.replace(f\".{component}_to_q.\", f\".{component}_to_q.\")]\n                            param[d:2*d, r:2*r] = state_dict_[name.replace(f\".{component}_to_q.\", f\".{component}_to_k.\")]\n                            param[2*d:3*d, 2*r:3*r] = state_dict_[name.replace(f\".{component}_to_q.\", f\".{component}_to_v.\")]\n                        else:\n                            param = torch.concat([\n                                state_dict_[name.replace(f\".{component}_to_q.\", f\".{component}_to_q.\")],\n                                state_dict_[name.replace(f\".{component}_to_q.\", f\".{component}_to_k.\")],\n                                state_dict_[name.replace(f\".{component}_to_q.\", f\".{component}_to_v.\")],\n                            ], dim=0)\n                        state_dict_[name_] = param\n                        state_dict_.pop(name.replace(f\".{component}_to_q.\", f\".{component}_to_q.\"))\n                        
state_dict_.pop(name.replace(f\".{component}_to_q.\", f\".{component}_to_k.\"))\n                        state_dict_.pop(name.replace(f\".{component}_to_q.\", f\".{component}_to_v.\"))  \n        return state_dict_\n\n\nclass FluxLoRAConverter:\n    def __init__(self):\n        pass\n\n    @staticmethod\n    def align_to_opensource_format(state_dict, alpha=None):\n        prefix_rename_dict = {\n            \"single_blocks\": \"lora_unet_single_blocks\",\n            \"blocks\": \"lora_unet_double_blocks\",\n        }\n        middle_rename_dict = {\n            \"norm.linear\": \"modulation_lin\",\n            \"to_qkv_mlp\": \"linear1\",\n            \"proj_out\": \"linear2\",\n\n            \"norm1_a.linear\": \"img_mod_lin\",\n            \"norm1_b.linear\": \"txt_mod_lin\",\n            \"attn.a_to_qkv\": \"img_attn_qkv\",\n            \"attn.b_to_qkv\": \"txt_attn_qkv\",\n            \"attn.a_to_out\": \"img_attn_proj\",\n            \"attn.b_to_out\": \"txt_attn_proj\",\n            \"ff_a.0\": \"img_mlp_0\",\n            \"ff_a.2\": \"img_mlp_2\",\n            \"ff_b.0\": \"txt_mlp_0\",\n            \"ff_b.2\": \"txt_mlp_2\",\n        }\n        suffix_rename_dict = {\n            \"lora_B.weight\": \"lora_up.weight\",\n            \"lora_A.weight\": \"lora_down.weight\",\n        }\n        state_dict_ = {}\n        for name, param in state_dict.items():\n            names = name.split(\".\")\n            if names[-2] != \"lora_A\" and names[-2] != \"lora_B\":\n                names.pop(-2)\n            prefix = names[0]\n            middle = \".\".join(names[2:-2])\n            suffix = \".\".join(names[-2:])\n            block_id = names[1]\n            if middle not in middle_rename_dict:\n                continue\n            rename = prefix_rename_dict[prefix] + \"_\" + block_id + \"_\" + middle_rename_dict[middle] + \".\" + suffix_rename_dict[suffix]\n            state_dict_[rename] = param\n            if rename.endswith(\"lora_up.weight\"):\n       
         lora_alpha = alpha if alpha is not None else param.shape[-1]\n                state_dict_[rename.replace(\"lora_up.weight\", \"alpha\")] = torch.tensor((lora_alpha,))[0]\n        return state_dict_\n    \n    @staticmethod\n    def align_to_diffsynth_format(state_dict):\n        rename_dict = {\n            \"lora_unet_double_blocks_blockid_img_mod_lin.lora_down.weight\": \"blocks.blockid.norm1_a.linear.lora_A.default.weight\",\n            \"lora_unet_double_blocks_blockid_img_mod_lin.lora_up.weight\": \"blocks.blockid.norm1_a.linear.lora_B.default.weight\",\n            \"lora_unet_double_blocks_blockid_txt_mod_lin.lora_down.weight\": \"blocks.blockid.norm1_b.linear.lora_A.default.weight\",\n            \"lora_unet_double_blocks_blockid_txt_mod_lin.lora_up.weight\": \"blocks.blockid.norm1_b.linear.lora_B.default.weight\",\n            \"lora_unet_double_blocks_blockid_img_attn_qkv.lora_down.weight\": \"blocks.blockid.attn.a_to_qkv.lora_A.default.weight\",\n            \"lora_unet_double_blocks_blockid_img_attn_qkv.lora_up.weight\": \"blocks.blockid.attn.a_to_qkv.lora_B.default.weight\",\n            \"lora_unet_double_blocks_blockid_txt_attn_qkv.lora_down.weight\": \"blocks.blockid.attn.b_to_qkv.lora_A.default.weight\",\n            \"lora_unet_double_blocks_blockid_txt_attn_qkv.lora_up.weight\": \"blocks.blockid.attn.b_to_qkv.lora_B.default.weight\",\n            \"lora_unet_double_blocks_blockid_img_attn_proj.lora_down.weight\": \"blocks.blockid.attn.a_to_out.lora_A.default.weight\",\n            \"lora_unet_double_blocks_blockid_img_attn_proj.lora_up.weight\": \"blocks.blockid.attn.a_to_out.lora_B.default.weight\",\n            \"lora_unet_double_blocks_blockid_txt_attn_proj.lora_down.weight\": \"blocks.blockid.attn.b_to_out.lora_A.default.weight\",\n            \"lora_unet_double_blocks_blockid_txt_attn_proj.lora_up.weight\": \"blocks.blockid.attn.b_to_out.lora_B.default.weight\",\n            
\"lora_unet_double_blocks_blockid_img_mlp_0.lora_down.weight\": \"blocks.blockid.ff_a.0.lora_A.default.weight\",\n            \"lora_unet_double_blocks_blockid_img_mlp_0.lora_up.weight\": \"blocks.blockid.ff_a.0.lora_B.default.weight\",\n            \"lora_unet_double_blocks_blockid_img_mlp_2.lora_down.weight\": \"blocks.blockid.ff_a.2.lora_A.default.weight\",\n            \"lora_unet_double_blocks_blockid_img_mlp_2.lora_up.weight\": \"blocks.blockid.ff_a.2.lora_B.default.weight\",\n            \"lora_unet_double_blocks_blockid_txt_mlp_0.lora_down.weight\": \"blocks.blockid.ff_b.0.lora_A.default.weight\",\n            \"lora_unet_double_blocks_blockid_txt_mlp_0.lora_up.weight\": \"blocks.blockid.ff_b.0.lora_B.default.weight\",\n            \"lora_unet_double_blocks_blockid_txt_mlp_2.lora_down.weight\": \"blocks.blockid.ff_b.2.lora_A.default.weight\",\n            \"lora_unet_double_blocks_blockid_txt_mlp_2.lora_up.weight\": \"blocks.blockid.ff_b.2.lora_B.default.weight\",\n            \"lora_unet_single_blocks_blockid_modulation_lin.lora_down.weight\": \"single_blocks.blockid.norm.linear.lora_A.default.weight\",\n            \"lora_unet_single_blocks_blockid_modulation_lin.lora_up.weight\": \"single_blocks.blockid.norm.linear.lora_B.default.weight\",\n            \"lora_unet_single_blocks_blockid_linear1.lora_down.weight\": \"single_blocks.blockid.to_qkv_mlp.lora_A.default.weight\",\n            \"lora_unet_single_blocks_blockid_linear1.lora_up.weight\": \"single_blocks.blockid.to_qkv_mlp.lora_B.default.weight\",\n            \"lora_unet_single_blocks_blockid_linear2.lora_down.weight\": \"single_blocks.blockid.proj_out.lora_A.default.weight\",\n            \"lora_unet_single_blocks_blockid_linear2.lora_up.weight\": \"single_blocks.blockid.proj_out.lora_B.default.weight\",\n        }\n        def guess_block_id(name):\n            names = name.split(\"_\")\n            for i in names:\n                if i.isdigit():\n                    return i, 
name.replace(f\"_{i}_\", \"_blockid_\")\n            return None, None\n        state_dict_ = {}\n        for name, param in state_dict.items():\n            block_id, source_name = guess_block_id(name)\n            if source_name in rename_dict:\n                target_name = rename_dict[source_name]\n                target_name = target_name.replace(\".blockid.\", f\".{block_id}.\")\n                state_dict_[target_name] = param\n            else:\n                state_dict_[name] = param\n        return state_dict_\n"
  },
  {
    "path": "diffsynth/utils/lora/general.py",
    "content": "import torch, warnings\n\n\nclass GeneralLoRALoader:\n    def __init__(self, device=\"cpu\", torch_dtype=torch.float32):\n        self.device = device\n        self.torch_dtype = torch_dtype\n    \n    \n    def get_name_dict(self, lora_state_dict):\n        lora_name_dict = {}\n        for key in lora_state_dict:\n            if \".lora_up.\" in key:\n                lora_A_key = \"lora_down\"\n                lora_B_key = \"lora_up\"\n            else:\n                lora_A_key = \"lora_A\"\n                lora_B_key = \"lora_B\"\n            if lora_B_key not in key:\n                continue\n            keys = key.split(\".\")\n            if len(keys) > keys.index(lora_B_key) + 2:\n                keys.pop(keys.index(lora_B_key) + 1)\n            keys.pop(keys.index(lora_B_key))\n            if keys[0] == \"diffusion_model\":\n                keys.pop(0)\n            keys.pop(-1)\n            target_name = \".\".join(keys)\n            # Alpha: Deprecated but retained for compatibility.\n            key_alpha = key.replace(lora_B_key + \".weight\", \"alpha\").replace(lora_B_key + \".default.weight\", \"alpha\")\n            if key_alpha == key or key_alpha not in lora_state_dict:\n                key_alpha = None\n            lora_name_dict[target_name] = (key, key.replace(lora_B_key, lora_A_key), key_alpha)\n        return lora_name_dict\n    \n    \n    def convert_state_dict(self, state_dict, suffix=\".weight\"):\n        name_dict = self.get_name_dict(state_dict)\n        state_dict_ = {}\n        for name in name_dict:\n            weight_up = state_dict[name_dict[name][0]]\n            weight_down = state_dict[name_dict[name][1]]\n            if name_dict[name][2] is not None:\n                warnings.warn(\"Alpha detected in the LoRA file. This may be a LoRA model not trained by DiffSynth-Studio. 
To ensure compatibility, the LoRA weights will be converted to weight * alpha / rank.\")\n                alpha = state_dict[name_dict[name][2]] / weight_down.shape[0]\n                weight_down = weight_down * alpha\n            state_dict_[name + f\".lora_B{suffix}\"] = weight_up\n            state_dict_[name + f\".lora_A{suffix}\"] = weight_down\n        return state_dict_\n\n\n    def fuse_lora_to_base_model(self, model: torch.nn.Module, state_dict, alpha=1.0):\n        updated_num = 0\n        state_dict = self.convert_state_dict(state_dict)\n        lora_layer_names = set([i.replace(\".lora_B.weight\", \"\") for i in state_dict if i.endswith(\".lora_B.weight\")])\n        for name, module in model.named_modules():\n            if name in lora_layer_names:\n                weight_up = state_dict[name + \".lora_B.weight\"].to(device=self.device, dtype=self.torch_dtype)\n                weight_down = state_dict[name + \".lora_A.weight\"].to(device=self.device, dtype=self.torch_dtype)\n                if len(weight_up.shape) == 4:\n                    weight_up = weight_up.squeeze(3).squeeze(2)\n                    weight_down = weight_down.squeeze(3).squeeze(2)\n                    weight_lora = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)\n                else:\n                    weight_lora = alpha * torch.mm(weight_up, weight_down)\n                state_dict_base = module.state_dict()\n                state_dict_base[\"weight\"] = state_dict_base[\"weight\"].to(device=self.device, dtype=self.torch_dtype) + weight_lora\n                module.load_state_dict(state_dict_base)\n                updated_num += 1\n        print(f\"{updated_num} tensors are fused by LoRA. Fused LoRA layers cannot be cleared by `pipe.clear_lora()`.\")\n"
  },
  {
    "path": "diffsynth/utils/lora/merge.py",
    "content": "import torch\nfrom typing import Dict, List\n\n\ndef merge_lora_weight(tensors_A, tensors_B):\n    lora_A = torch.concat(tensors_A, dim=0)\n    lora_B = torch.concat(tensors_B, dim=1)\n    return lora_A, lora_B\n\n\ndef merge_lora(loras: List[Dict[str, torch.Tensor]], alpha=1):\n    lora_merged = {}\n    keys = [i for i in loras[0].keys() if \".lora_A.\" in i]\n    for key in keys:\n        tensors_A = [lora[key] for lora in loras]\n        tensors_B = [lora[key.replace(\".lora_A.\", \".lora_B.\")] for lora in loras]\n        lora_A, lora_B = merge_lora_weight(tensors_A, tensors_B)\n        lora_merged[key] = lora_A * alpha\n        lora_merged[key.replace(\".lora_A.\", \".lora_B.\")] = lora_B\n    return lora_merged\n"
  },
  {
    "path": "diffsynth/utils/lora/reset_rank.py",
    "content": "import torch\n\ndef decomposite(tensor_A, tensor_B, rank):\n    dtype, device = tensor_A.dtype, tensor_A.device\n    weight = tensor_B @ tensor_A\n    U, S, V = torch.pca_lowrank(weight.float(), q=rank)\n    tensor_A = (V.T).to(dtype=dtype, device=device).contiguous()\n    tensor_B = (U @ torch.diag(S)).to(dtype=dtype, device=device).contiguous()\n    return tensor_A, tensor_B\n\ndef reset_lora_rank(lora, rank):\n    lora_merged = {}\n    keys = [i for i in lora.keys() if \".lora_A.\" in i]\n    for key in keys:\n        tensor_A = lora[key]\n        tensor_B = lora[key.replace(\".lora_A.\", \".lora_B.\")]\n        tensor_A, tensor_B = decomposite(tensor_A, tensor_B, rank)\n        lora_merged[key] = tensor_A\n        lora_merged[key.replace(\".lora_A.\", \".lora_B.\")] = tensor_B\n    return lora_merged"
  },
  {
    "path": "diffsynth/utils/ses/README.md",
    "content": "Please see `docs/en/Research_Tutorial/inference_time_scaling.md` or `docs/zh/Research_Tutorial/inference_time_scaling.md` for more details.\n"
  },
  {
    "path": "diffsynth/utils/ses/__init__.py",
    "content": "from .ses import ses_search"
  },
  {
    "path": "diffsynth/utils/ses/ses.py",
    "content": "import torch\nimport pywt\nimport numpy as np\nfrom tqdm import tqdm\n\n\ndef split_dwt(z_tensor_cpu, wavelet_name, dwt_level):\n    all_clow_np = []\n    all_chigh_list = []\n    z_tensor_cpu = z_tensor_cpu.cpu().float()\n    \n    for i in range(z_tensor_cpu.shape[0]): \n        z_numpy_ch = z_tensor_cpu[i].numpy()\n        \n        coeffs_ch = pywt.wavedec2(z_numpy_ch, wavelet_name, level=dwt_level, mode='symmetric', axes=(-2, -1))\n        \n        clow_np = coeffs_ch[0]\n        chigh_list = coeffs_ch[1:]\n        \n        all_clow_np.append(clow_np)\n        all_chigh_list.append(chigh_list)\n        \n    all_clow_tensor = torch.from_numpy(np.stack(all_clow_np, axis=0))\n    return all_clow_tensor, all_chigh_list\n\n\ndef reconstruct_dwt(c_low_tensor_cpu, c_high_coeffs, wavelet_name, original_shape):\n    H_high, W_high = original_shape\n    c_low_tensor_cpu = c_low_tensor_cpu.cpu().float()\n    \n    clow_np = c_low_tensor_cpu.numpy()\n    \n    if clow_np.ndim == 4 and clow_np.shape[0] == 1:\n        clow_np = clow_np[0]\n\n    coeffs_combined = [clow_np] + c_high_coeffs\n    z_recon_np = pywt.waverec2(coeffs_combined, wavelet_name, mode='symmetric', axes=(-2, -1))\n    if z_recon_np.shape[-2] != H_high or z_recon_np.shape[-1] != W_high:\n        z_recon_np = z_recon_np[..., :H_high, :W_high]\n    z_recon_tensor = torch.from_numpy(z_recon_np)\n    if z_recon_tensor.ndim == 3:\n        z_recon_tensor = z_recon_tensor.unsqueeze(0)\n    return z_recon_tensor\n\n\ndef ses_search(\n    base_latents,\n    objective_reward_fn,\n    total_eval_budget=30,\n    popsize=10,\n    k_elites=5,\n    wavelet_name=\"db1\",\n    dwt_level=4,\n):\n    latent_h, latent_w = base_latents.shape[-2], base_latents.shape[-1]\n    c_low_init, c_high_fixed_batch = split_dwt(base_latents, wavelet_name, dwt_level)\n    c_high_fixed = c_high_fixed_batch[0]    \n    c_low_shape = c_low_init.shape[1:]\n    mu = torch.zeros_like(c_low_init.view(-1).cpu()) \n    sigma_sq 
= torch.ones_like(mu) * 1.0 \n    \n    best_overall = {\"fitness\": -float('inf'), \"score\": -float('inf'), \"c_low\": c_low_init[0]}\n    eval_count = 0\n    \n    elite_db = []    \n    n_generations = (total_eval_budget // popsize) + 5\n    pbar = tqdm(total=total_eval_budget, desc=\"[SES] Searching\", unit=\"img\")\n\n    for gen in range(n_generations):\n        if eval_count >= total_eval_budget: break\n        \n        std = torch.sqrt(torch.clamp(sigma_sq, min=1e-9))\n        z_noise = torch.randn(popsize, mu.shape[0])\n        samples_flat = mu + z_noise * std\n        samples_reshaped = samples_flat.view(popsize, *c_low_shape) \n        \n        batch_results = []\n        \n        for i in range(popsize):\n            if eval_count >= total_eval_budget: break\n            \n            c_low_sample = samples_reshaped[i].unsqueeze(0) \n            z_recon = reconstruct_dwt(c_low_sample, c_high_fixed, wavelet_name, (latent_h, latent_w))\n            z_recon = z_recon.to(base_latents.device, dtype=base_latents.dtype)  \n            # img = pipeline_callback(z_recon)\n\n            # score = scorer.get_score(img, prompt)\n            score = objective_reward_fn(z_recon)\n            res = {\n                \"score\": score, \n                \"c_low\": c_low_sample.cpu()\n            }\n            batch_results.append(res)\n            if score > best_overall['score']:\n                best_overall = res\n                \n            eval_count += 1\n            pbar.update(1)\n            \n        if not batch_results: break\n        elite_db.extend(batch_results)        \n        elite_db.sort(key=lambda x: x['score'], reverse=True)        \n        elite_db = elite_db[:k_elites]        \n        elites_flat = torch.stack([x['c_low'].view(-1) for x in elite_db])\n        mu_new = torch.mean(elites_flat, dim=0)\n        \n        if len(elite_db) > 1:\n            sigma_sq_new = torch.var(elites_flat, dim=0, unbiased=True) + 1e-7\n        else:\n   
         sigma_sq_new = sigma_sq\n        mu = mu_new\n        sigma_sq = sigma_sq_new\n    pbar.close()\n    best_c_low = best_overall['c_low']\n    final_latents = reconstruct_dwt(best_c_low, c_high_fixed, wavelet_name, (latent_h, latent_w))\n    \n    return final_latents.to(base_latents.device, dtype=base_latents.dtype)\n"
  },
  {
    "path": "diffsynth/utils/state_dict_converters/__init__.py",
    "content": ""
  },
  {
    "path": "diffsynth/utils/state_dict_converters/anima_dit.py",
    "content": "def AnimaDiTStateDictConverter(state_dict):\n    new_state_dict = {}\n    for key in state_dict:\n        value = state_dict[key]\n        new_state_dict[key.replace(\"net.\", \"\")] = value\n    return new_state_dict\n"
  },
  {
    "path": "diffsynth/utils/state_dict_converters/flux2_text_encoder.py",
    "content": "def Flux2TextEncoderStateDictConverter(state_dict):\n    rename_dict = {\n        \"multi_modal_projector.linear_1.weight\": \"model.multi_modal_projector.linear_1.weight\",\n        \"multi_modal_projector.linear_2.weight\": \"model.multi_modal_projector.linear_2.weight\",\n        \"multi_modal_projector.norm.weight\": \"model.multi_modal_projector.norm.weight\",\n        \"multi_modal_projector.patch_merger.merging_layer.weight\": \"model.multi_modal_projector.patch_merger.merging_layer.weight\",\n        \"language_model.lm_head.weight\": \"lm_head.weight\",\n    }\n    state_dict_ = {}\n    for k in state_dict:\n        k_ = k\n        k_ = k_.replace(\"language_model.model\", \"model.language_model\")\n        k_ = k_.replace(\"vision_tower\", \"model.vision_tower\")\n        if k_ in rename_dict:\n            k_ = rename_dict[k_]\n        state_dict_[k_] = state_dict[k]\n    return state_dict_\n"
  },
  {
    "path": "diffsynth/utils/state_dict_converters/flux_controlnet.py",
    "content": "import torch\n\n\ndef FluxControlNetStateDictConverter(state_dict):\n    global_rename_dict = {\n        \"context_embedder\": \"context_embedder\",\n        \"x_embedder\": \"x_embedder\",\n        \"time_text_embed.timestep_embedder.linear_1\": \"time_embedder.timestep_embedder.0\",\n        \"time_text_embed.timestep_embedder.linear_2\": \"time_embedder.timestep_embedder.2\",\n        \"time_text_embed.guidance_embedder.linear_1\": \"guidance_embedder.timestep_embedder.0\",\n        \"time_text_embed.guidance_embedder.linear_2\": \"guidance_embedder.timestep_embedder.2\",\n        \"time_text_embed.text_embedder.linear_1\": \"pooled_text_embedder.0\",\n        \"time_text_embed.text_embedder.linear_2\": \"pooled_text_embedder.2\",\n        \"norm_out.linear\": \"final_norm_out.linear\",\n        \"proj_out\": \"final_proj_out\",\n    }\n    rename_dict = {\n        \"proj_out\": \"proj_out\",\n        \"norm1.linear\": \"norm1_a.linear\",\n        \"norm1_context.linear\": \"norm1_b.linear\",\n        \"attn.to_q\": \"attn.a_to_q\",\n        \"attn.to_k\": \"attn.a_to_k\",\n        \"attn.to_v\": \"attn.a_to_v\",\n        \"attn.to_out.0\": \"attn.a_to_out\",\n        \"attn.add_q_proj\": \"attn.b_to_q\",\n        \"attn.add_k_proj\": \"attn.b_to_k\",\n        \"attn.add_v_proj\": \"attn.b_to_v\",\n        \"attn.to_add_out\": \"attn.b_to_out\",\n        \"ff.net.0.proj\": \"ff_a.0\",\n        \"ff.net.2\": \"ff_a.2\",\n        \"ff_context.net.0.proj\": \"ff_b.0\",\n        \"ff_context.net.2\": \"ff_b.2\",\n        \"attn.norm_q\": \"attn.norm_q_a\",\n        \"attn.norm_k\": \"attn.norm_k_a\",\n        \"attn.norm_added_q\": \"attn.norm_q_b\",\n        \"attn.norm_added_k\": \"attn.norm_k_b\",\n    }\n    rename_dict_single = {\n        \"attn.to_q\": \"a_to_q\",\n        \"attn.to_k\": \"a_to_k\",\n        \"attn.to_v\": \"a_to_v\",\n        \"attn.norm_q\": \"norm_q_a\",\n        \"attn.norm_k\": \"norm_k_a\",\n        \"norm.linear\": 
\"norm.linear\",\n        \"proj_mlp\": \"proj_in_besides_attn\",\n        \"proj_out\": \"proj_out\",\n    }\n    state_dict_ = {}\n\n    for name in state_dict:\n        param = state_dict[name]\n        if name.endswith(\".weight\") or name.endswith(\".bias\"):\n            suffix = \".weight\" if name.endswith(\".weight\") else \".bias\"\n            prefix = name[:-len(suffix)]\n            if prefix in global_rename_dict:\n                state_dict_[global_rename_dict[prefix] + suffix] = param\n            elif prefix.startswith(\"transformer_blocks.\"):\n                names = prefix.split(\".\")\n                names[0] = \"blocks\"\n                middle = \".\".join(names[2:])\n                if middle in rename_dict:\n                    name_ = \".\".join(names[:2] + [rename_dict[middle]] + [suffix[1:]])\n                    state_dict_[name_] = param\n            elif prefix.startswith(\"single_transformer_blocks.\"):\n                names = prefix.split(\".\")\n                names[0] = \"single_blocks\"\n                middle = \".\".join(names[2:])\n                if middle in rename_dict_single:\n                    name_ = \".\".join(names[:2] + [rename_dict_single[middle]] + [suffix[1:]])\n                    state_dict_[name_] = param\n                else:\n                    state_dict_[name] = param\n            else:\n                state_dict_[name] = param\n    for name in list(state_dict_.keys()):\n        if \".proj_in_besides_attn.\" in name:\n            name_ = name.replace(\".proj_in_besides_attn.\", \".to_qkv_mlp.\")\n            param = torch.concat([\n                state_dict_[name.replace(\".proj_in_besides_attn.\", f\".a_to_q.\")],\n                state_dict_[name.replace(\".proj_in_besides_attn.\", f\".a_to_k.\")],\n                state_dict_[name.replace(\".proj_in_besides_attn.\", f\".a_to_v.\")],\n                state_dict_[name],\n            ], dim=0)\n            state_dict_[name_] = param\n            
state_dict_.pop(name.replace(\".proj_in_besides_attn.\", f\".a_to_q.\"))\n            state_dict_.pop(name.replace(\".proj_in_besides_attn.\", f\".a_to_k.\"))\n            state_dict_.pop(name.replace(\".proj_in_besides_attn.\", f\".a_to_v.\"))\n            state_dict_.pop(name)\n    for name in list(state_dict_.keys()):\n        for component in [\"a\", \"b\"]:\n            if f\".{component}_to_q.\" in name:\n                name_ = name.replace(f\".{component}_to_q.\", f\".{component}_to_qkv.\")\n                param = torch.concat([\n                    state_dict_[name.replace(f\".{component}_to_q.\", f\".{component}_to_q.\")],\n                    state_dict_[name.replace(f\".{component}_to_q.\", f\".{component}_to_k.\")],\n                    state_dict_[name.replace(f\".{component}_to_q.\", f\".{component}_to_v.\")],\n                ], dim=0)\n                state_dict_[name_] = param\n                state_dict_.pop(name.replace(f\".{component}_to_q.\", f\".{component}_to_q.\"))\n                state_dict_.pop(name.replace(f\".{component}_to_q.\", f\".{component}_to_k.\"))\n                state_dict_.pop(name.replace(f\".{component}_to_q.\", f\".{component}_to_v.\"))\n    \n    return state_dict_"
  },
  {
    "path": "diffsynth/utils/state_dict_converters/flux_dit.py",
    "content": "import torch\n\n\ndef FluxDiTStateDictConverter(state_dict):\n    is_nexus_gen = sum([key.startswith(\"pipe.dit.\") for key in state_dict]) > 0\n    if is_nexus_gen:\n        dit_state_dict = {}\n        for key in state_dict:\n            if key.startswith('pipe.dit.'):\n                param = state_dict[key]\n                new_key = key.replace(\"pipe.dit.\", \"\")\n                if new_key.startswith(\"final_norm_out.linear.\"):\n                    param = torch.concat([param[3072:], param[:3072]], dim=0)\n                dit_state_dict[new_key] = param\n        return dit_state_dict\n\n    rename_dict = {\n        \"time_in.in_layer.bias\": \"time_embedder.timestep_embedder.0.bias\",\n        \"time_in.in_layer.weight\": \"time_embedder.timestep_embedder.0.weight\",\n        \"time_in.out_layer.bias\": \"time_embedder.timestep_embedder.2.bias\",\n        \"time_in.out_layer.weight\": \"time_embedder.timestep_embedder.2.weight\",\n        \"txt_in.bias\": \"context_embedder.bias\",\n        \"txt_in.weight\": \"context_embedder.weight\",\n        \"vector_in.in_layer.bias\": \"pooled_text_embedder.0.bias\",\n        \"vector_in.in_layer.weight\": \"pooled_text_embedder.0.weight\",\n        \"vector_in.out_layer.bias\": \"pooled_text_embedder.2.bias\",\n        \"vector_in.out_layer.weight\": \"pooled_text_embedder.2.weight\",\n        \"final_layer.linear.bias\": \"final_proj_out.bias\",\n        \"final_layer.linear.weight\": \"final_proj_out.weight\",\n        \"guidance_in.in_layer.bias\": \"guidance_embedder.timestep_embedder.0.bias\",\n        \"guidance_in.in_layer.weight\": \"guidance_embedder.timestep_embedder.0.weight\",\n        \"guidance_in.out_layer.bias\": \"guidance_embedder.timestep_embedder.2.bias\",\n        \"guidance_in.out_layer.weight\": \"guidance_embedder.timestep_embedder.2.weight\",\n        \"img_in.bias\": \"x_embedder.bias\",\n        \"img_in.weight\": \"x_embedder.weight\",\n        
\"final_layer.adaLN_modulation.1.weight\": \"final_norm_out.linear.weight\",\n        \"final_layer.adaLN_modulation.1.bias\": \"final_norm_out.linear.bias\",\n    }\n    suffix_rename_dict = {\n        \"img_attn.norm.key_norm.scale\": \"attn.norm_k_a.weight\",\n        \"img_attn.norm.query_norm.scale\": \"attn.norm_q_a.weight\",\n        \"img_attn.proj.bias\": \"attn.a_to_out.bias\",\n        \"img_attn.proj.weight\": \"attn.a_to_out.weight\",\n        \"img_attn.qkv.bias\": \"attn.a_to_qkv.bias\",\n        \"img_attn.qkv.weight\": \"attn.a_to_qkv.weight\",\n        \"img_mlp.0.bias\": \"ff_a.0.bias\",\n        \"img_mlp.0.weight\": \"ff_a.0.weight\",\n        \"img_mlp.2.bias\": \"ff_a.2.bias\",\n        \"img_mlp.2.weight\": \"ff_a.2.weight\",\n        \"img_mod.lin.bias\": \"norm1_a.linear.bias\",\n        \"img_mod.lin.weight\": \"norm1_a.linear.weight\",\n        \"txt_attn.norm.key_norm.scale\": \"attn.norm_k_b.weight\",\n        \"txt_attn.norm.query_norm.scale\": \"attn.norm_q_b.weight\",\n        \"txt_attn.proj.bias\": \"attn.b_to_out.bias\",\n        \"txt_attn.proj.weight\": \"attn.b_to_out.weight\",\n        \"txt_attn.qkv.bias\": \"attn.b_to_qkv.bias\",\n        \"txt_attn.qkv.weight\": \"attn.b_to_qkv.weight\",\n        \"txt_mlp.0.bias\": \"ff_b.0.bias\",\n        \"txt_mlp.0.weight\": \"ff_b.0.weight\",\n        \"txt_mlp.2.bias\": \"ff_b.2.bias\",\n        \"txt_mlp.2.weight\": \"ff_b.2.weight\",\n        \"txt_mod.lin.bias\": \"norm1_b.linear.bias\",\n        \"txt_mod.lin.weight\": \"norm1_b.linear.weight\",\n\n        \"linear1.bias\": \"to_qkv_mlp.bias\",\n        \"linear1.weight\": \"to_qkv_mlp.weight\",\n        \"linear2.bias\": \"proj_out.bias\",\n        \"linear2.weight\": \"proj_out.weight\",\n        \"modulation.lin.bias\": \"norm.linear.bias\",\n        \"modulation.lin.weight\": \"norm.linear.weight\",\n        \"norm.key_norm.scale\": \"norm_k_a.weight\",\n        \"norm.query_norm.scale\": \"norm_q_a.weight\",\n    }\n    
state_dict_ = {}\n    for name in state_dict:\n        original_name = name\n        if name.startswith(\"model.diffusion_model.\"):\n            name = name[len(\"model.diffusion_model.\"):]\n        names = name.split(\".\")\n        if name in rename_dict:\n            rename = rename_dict[name]\n            state_dict_[rename] = state_dict[original_name]\n        elif names[0] == \"double_blocks\":\n            rename = f\"blocks.{names[1]}.\" + suffix_rename_dict[\".\".join(names[2:])]\n            state_dict_[rename] = state_dict[original_name]\n        elif names[0] == \"single_blocks\":\n            if \".\".join(names[2:]) in suffix_rename_dict:\n                rename = f\"single_blocks.{names[1]}.\" + suffix_rename_dict[\".\".join(names[2:])]\n                state_dict_[rename] = state_dict[original_name]\n        else:\n            pass\n    return state_dict_\n\n\ndef FluxDiTStateDictConverterFromDiffusers(state_dict):\n    global_rename_dict = {\n        \"context_embedder\": \"context_embedder\",\n        \"x_embedder\": \"x_embedder\",\n        \"time_text_embed.timestep_embedder.linear_1\": \"time_embedder.timestep_embedder.0\",\n        \"time_text_embed.timestep_embedder.linear_2\": \"time_embedder.timestep_embedder.2\",\n        \"time_text_embed.guidance_embedder.linear_1\": \"guidance_embedder.timestep_embedder.0\",\n        \"time_text_embed.guidance_embedder.linear_2\": \"guidance_embedder.timestep_embedder.2\",\n        \"time_text_embed.text_embedder.linear_1\": \"pooled_text_embedder.0\",\n        \"time_text_embed.text_embedder.linear_2\": \"pooled_text_embedder.2\",\n        \"norm_out.linear\": \"final_norm_out.linear\",\n        \"proj_out\": \"final_proj_out\",\n    }\n    rename_dict = {\n        \"proj_out\": \"proj_out\",\n        \"norm1.linear\": \"norm1_a.linear\",\n        \"norm1_context.linear\": \"norm1_b.linear\",\n        \"attn.to_q\": \"attn.a_to_q\",\n        \"attn.to_k\": \"attn.a_to_k\",\n        \"attn.to_v\": 
\"attn.a_to_v\",\n        \"attn.to_out.0\": \"attn.a_to_out\",\n        \"attn.add_q_proj\": \"attn.b_to_q\",\n        \"attn.add_k_proj\": \"attn.b_to_k\",\n        \"attn.add_v_proj\": \"attn.b_to_v\",\n        \"attn.to_add_out\": \"attn.b_to_out\",\n        \"ff.net.0.proj\": \"ff_a.0\",\n        \"ff.net.2\": \"ff_a.2\",\n        \"ff_context.net.0.proj\": \"ff_b.0\",\n        \"ff_context.net.2\": \"ff_b.2\",\n        \"attn.norm_q\": \"attn.norm_q_a\",\n        \"attn.norm_k\": \"attn.norm_k_a\",\n        \"attn.norm_added_q\": \"attn.norm_q_b\",\n        \"attn.norm_added_k\": \"attn.norm_k_b\",\n    }\n    rename_dict_single = {\n        \"attn.to_q\": \"a_to_q\",\n        \"attn.to_k\": \"a_to_k\",\n        \"attn.to_v\": \"a_to_v\",\n        \"attn.norm_q\": \"norm_q_a\",\n        \"attn.norm_k\": \"norm_k_a\",\n        \"norm.linear\": \"norm.linear\",\n        \"proj_mlp\": \"proj_in_besides_attn\",\n        \"proj_out\": \"proj_out\",\n    }\n    state_dict_ = {}\n    for name in state_dict:\n        param = state_dict[name]\n        if name.endswith(\".weight\") or name.endswith(\".bias\"):\n            suffix = \".weight\" if name.endswith(\".weight\") else \".bias\"\n            prefix = name[:-len(suffix)]\n            if prefix in global_rename_dict:\n                if global_rename_dict[prefix] == \"final_norm_out.linear\":\n                    param = torch.concat([param[3072:], param[:3072]], dim=0)\n                state_dict_[global_rename_dict[prefix] + suffix] = param\n            elif prefix.startswith(\"transformer_blocks.\"):\n                names = prefix.split(\".\")\n                names[0] = \"blocks\"\n                middle = \".\".join(names[2:])\n                if middle in rename_dict:\n                    name_ = \".\".join(names[:2] + [rename_dict[middle]] + [suffix[1:]])\n                    state_dict_[name_] = param\n            elif prefix.startswith(\"single_transformer_blocks.\"):\n                names = 
prefix.split(\".\")\n                names[0] = \"single_blocks\"\n                middle = \".\".join(names[2:])\n                if middle in rename_dict_single:\n                    name_ = \".\".join(names[:2] + [rename_dict_single[middle]] + [suffix[1:]])\n                    state_dict_[name_] = param\n                else:\n                    pass\n            else:\n                pass\n    for name in list(state_dict_.keys()):\n        if \"single_blocks.\" in name and \".a_to_q.\" in name:\n            mlp = state_dict_.get(name.replace(\".a_to_q.\", \".proj_in_besides_attn.\"), None)\n            if mlp is None:\n                mlp = torch.zeros(4 * state_dict_[name].shape[0],\n                                    *state_dict_[name].shape[1:],\n                                    dtype=state_dict_[name].dtype)\n            else:\n                state_dict_.pop(name.replace(\".a_to_q.\", \".proj_in_besides_attn.\"))\n            param = torch.concat([\n                state_dict_.pop(name),\n                state_dict_.pop(name.replace(\".a_to_q.\", \".a_to_k.\")),\n                state_dict_.pop(name.replace(\".a_to_q.\", \".a_to_v.\")),\n                mlp,\n            ], dim=0)\n            name_ = name.replace(\".a_to_q.\", \".to_qkv_mlp.\")\n            state_dict_[name_] = param\n    for name in list(state_dict_.keys()):\n        for component in [\"a\", \"b\"]:\n            if f\".{component}_to_q.\" in name:\n                name_ = name.replace(f\".{component}_to_q.\", f\".{component}_to_qkv.\")\n                param = torch.concat([\n                    state_dict_[name.replace(f\".{component}_to_q.\", f\".{component}_to_q.\")],\n                    state_dict_[name.replace(f\".{component}_to_q.\", f\".{component}_to_k.\")],\n                    state_dict_[name.replace(f\".{component}_to_q.\", f\".{component}_to_v.\")],\n                ], dim=0)\n                state_dict_[name_] = param\n                
state_dict_.pop(name.replace(f\".{component}_to_q.\", f\".{component}_to_q.\"))\n                state_dict_.pop(name.replace(f\".{component}_to_q.\", f\".{component}_to_k.\"))\n                state_dict_.pop(name.replace(f\".{component}_to_q.\", f\".{component}_to_v.\"))\n    return state_dict_"
  },
  {
    "path": "diffsynth/utils/state_dict_converters/flux_infiniteyou.py",
    "content": "def FluxInfiniteYouImageProjectorStateDictConverter(state_dict):\n    return state_dict['image_proj']"
  },
  {
    "path": "diffsynth/utils/state_dict_converters/flux_ipadapter.py",
    "content": "def FluxIpAdapterStateDictConverter(state_dict):\n    state_dict_ = {}\n    \n    if \"ip_adapter\" in state_dict and isinstance(state_dict[\"ip_adapter\"], dict):\n        for name, param in state_dict[\"ip_adapter\"].items():\n            name_ = 'ipadapter_modules.' + name\n            state_dict_[name_] = param\n        \n        if \"image_proj\" in state_dict:\n            for name, param in state_dict[\"image_proj\"].items():\n                name_ = \"image_proj.\" + name\n                state_dict_[name_] = param\n        return state_dict_\n\n    for key, value in state_dict.items():\n        if key.startswith(\"image_proj.\"):\n            state_dict_[key] = value\n        elif key.startswith(\"ip_adapter.\"):\n            new_key = key.replace(\"ip_adapter.\", \"ipadapter_modules.\")\n            state_dict_[new_key] = value\n        else:\n            pass\n            \n    return state_dict_\n\n\ndef SiglipStateDictConverter(state_dict):\n    new_state_dict = {}\n    for key in state_dict:\n        if key.startswith(\"vision_model.\"):\n            new_state_dict[key] = state_dict[key] \n    return new_state_dict"
  },
  {
    "path": "diffsynth/utils/state_dict_converters/flux_text_encoder_clip.py",
    "content": "def FluxTextEncoderClipStateDictConverter(state_dict):\n    rename_dict = {\n        \"text_model.embeddings.token_embedding.weight\": \"token_embedding.weight\",\n        \"text_model.embeddings.position_embedding.weight\": \"position_embeds\",\n        \"text_model.final_layer_norm.weight\": \"final_layer_norm.weight\",\n        \"text_model.final_layer_norm.bias\": \"final_layer_norm.bias\",\n    }\n    attn_rename_dict = {\n        \"self_attn.q_proj\": \"attn.to_q\",\n        \"self_attn.k_proj\": \"attn.to_k\",\n        \"self_attn.v_proj\": \"attn.to_v\",\n        \"self_attn.out_proj\": \"attn.to_out\",\n        \"layer_norm1\": \"layer_norm1\",\n        \"layer_norm2\": \"layer_norm2\",\n        \"mlp.fc1\": \"fc1\",\n        \"mlp.fc2\": \"fc2\",\n    }\n    state_dict_ = {}\n    for name in state_dict:\n        if name in rename_dict:\n            param = state_dict[name]\n            if name == \"text_model.embeddings.position_embedding.weight\":\n                param = param.reshape((1, param.shape[0], param.shape[1]))\n            state_dict_[rename_dict[name]] = param\n        elif name.startswith(\"text_model.encoder.layers.\"):\n            param = state_dict[name]\n            names = name.split(\".\")\n            layer_id, layer_type, tail = names[3], \".\".join(names[4:-1]), names[-1]\n            name_ = \".\".join([\"encoders\", layer_id, attn_rename_dict[layer_type], tail])\n            state_dict_[name_] = param\n    return state_dict_\n"
  },
  {
    "path": "diffsynth/utils/state_dict_converters/flux_text_encoder_t5.py",
    "content": "def FluxTextEncoderT5StateDictConverter(state_dict):\n    state_dict_ = {i: state_dict[i] for i in state_dict}\n    state_dict_[\"encoder.embed_tokens.weight\"] = state_dict[\"shared.weight\"]\n    return state_dict_\n"
  },
  {
    "path": "diffsynth/utils/state_dict_converters/flux_vae.py",
    "content": "def FluxVAEEncoderStateDictConverter(state_dict):\n    rename_dict = {\n        \"encoder.conv_in.bias\": \"conv_in.bias\",\n        \"encoder.conv_in.weight\": \"conv_in.weight\",\n        \"encoder.conv_out.bias\": \"conv_out.bias\",\n        \"encoder.conv_out.weight\": \"conv_out.weight\",\n        \"encoder.down.0.block.0.conv1.bias\": \"blocks.0.conv1.bias\",\n        \"encoder.down.0.block.0.conv1.weight\": \"blocks.0.conv1.weight\",\n        \"encoder.down.0.block.0.conv2.bias\": \"blocks.0.conv2.bias\",\n        \"encoder.down.0.block.0.conv2.weight\": \"blocks.0.conv2.weight\",\n        \"encoder.down.0.block.0.norm1.bias\": \"blocks.0.norm1.bias\",\n        \"encoder.down.0.block.0.norm1.weight\": \"blocks.0.norm1.weight\",\n        \"encoder.down.0.block.0.norm2.bias\": \"blocks.0.norm2.bias\",\n        \"encoder.down.0.block.0.norm2.weight\": \"blocks.0.norm2.weight\",\n        \"encoder.down.0.block.1.conv1.bias\": \"blocks.1.conv1.bias\",\n        \"encoder.down.0.block.1.conv1.weight\": \"blocks.1.conv1.weight\",\n        \"encoder.down.0.block.1.conv2.bias\": \"blocks.1.conv2.bias\",\n        \"encoder.down.0.block.1.conv2.weight\": \"blocks.1.conv2.weight\",\n        \"encoder.down.0.block.1.norm1.bias\": \"blocks.1.norm1.bias\",\n        \"encoder.down.0.block.1.norm1.weight\": \"blocks.1.norm1.weight\",\n        \"encoder.down.0.block.1.norm2.bias\": \"blocks.1.norm2.bias\",\n        \"encoder.down.0.block.1.norm2.weight\": \"blocks.1.norm2.weight\",\n        \"encoder.down.0.downsample.conv.bias\": \"blocks.2.conv.bias\",\n        \"encoder.down.0.downsample.conv.weight\": \"blocks.2.conv.weight\",\n        \"encoder.down.1.block.0.conv1.bias\": \"blocks.3.conv1.bias\",\n        \"encoder.down.1.block.0.conv1.weight\": \"blocks.3.conv1.weight\",\n        \"encoder.down.1.block.0.conv2.bias\": \"blocks.3.conv2.bias\",\n        \"encoder.down.1.block.0.conv2.weight\": \"blocks.3.conv2.weight\",\n        
\"encoder.down.1.block.0.nin_shortcut.bias\": \"blocks.3.conv_shortcut.bias\",\n        \"encoder.down.1.block.0.nin_shortcut.weight\": \"blocks.3.conv_shortcut.weight\",\n        \"encoder.down.1.block.0.norm1.bias\": \"blocks.3.norm1.bias\",\n        \"encoder.down.1.block.0.norm1.weight\": \"blocks.3.norm1.weight\",\n        \"encoder.down.1.block.0.norm2.bias\": \"blocks.3.norm2.bias\",\n        \"encoder.down.1.block.0.norm2.weight\": \"blocks.3.norm2.weight\",\n        \"encoder.down.1.block.1.conv1.bias\": \"blocks.4.conv1.bias\",\n        \"encoder.down.1.block.1.conv1.weight\": \"blocks.4.conv1.weight\",\n        \"encoder.down.1.block.1.conv2.bias\": \"blocks.4.conv2.bias\",\n        \"encoder.down.1.block.1.conv2.weight\": \"blocks.4.conv2.weight\",\n        \"encoder.down.1.block.1.norm1.bias\": \"blocks.4.norm1.bias\",\n        \"encoder.down.1.block.1.norm1.weight\": \"blocks.4.norm1.weight\",\n        \"encoder.down.1.block.1.norm2.bias\": \"blocks.4.norm2.bias\",\n        \"encoder.down.1.block.1.norm2.weight\": \"blocks.4.norm2.weight\",\n        \"encoder.down.1.downsample.conv.bias\": \"blocks.5.conv.bias\",\n        \"encoder.down.1.downsample.conv.weight\": \"blocks.5.conv.weight\",\n        \"encoder.down.2.block.0.conv1.bias\": \"blocks.6.conv1.bias\",\n        \"encoder.down.2.block.0.conv1.weight\": \"blocks.6.conv1.weight\",\n        \"encoder.down.2.block.0.conv2.bias\": \"blocks.6.conv2.bias\",\n        \"encoder.down.2.block.0.conv2.weight\": \"blocks.6.conv2.weight\",\n        \"encoder.down.2.block.0.nin_shortcut.bias\": \"blocks.6.conv_shortcut.bias\",\n        \"encoder.down.2.block.0.nin_shortcut.weight\": \"blocks.6.conv_shortcut.weight\",\n        \"encoder.down.2.block.0.norm1.bias\": \"blocks.6.norm1.bias\",\n        \"encoder.down.2.block.0.norm1.weight\": \"blocks.6.norm1.weight\",\n        \"encoder.down.2.block.0.norm2.bias\": \"blocks.6.norm2.bias\",\n        \"encoder.down.2.block.0.norm2.weight\": 
\"blocks.6.norm2.weight\",\n        \"encoder.down.2.block.1.conv1.bias\": \"blocks.7.conv1.bias\",\n        \"encoder.down.2.block.1.conv1.weight\": \"blocks.7.conv1.weight\",\n        \"encoder.down.2.block.1.conv2.bias\": \"blocks.7.conv2.bias\",\n        \"encoder.down.2.block.1.conv2.weight\": \"blocks.7.conv2.weight\",\n        \"encoder.down.2.block.1.norm1.bias\": \"blocks.7.norm1.bias\",\n        \"encoder.down.2.block.1.norm1.weight\": \"blocks.7.norm1.weight\",\n        \"encoder.down.2.block.1.norm2.bias\": \"blocks.7.norm2.bias\",\n        \"encoder.down.2.block.1.norm2.weight\": \"blocks.7.norm2.weight\",\n        \"encoder.down.2.downsample.conv.bias\": \"blocks.8.conv.bias\",\n        \"encoder.down.2.downsample.conv.weight\": \"blocks.8.conv.weight\",\n        \"encoder.down.3.block.0.conv1.bias\": \"blocks.9.conv1.bias\",\n        \"encoder.down.3.block.0.conv1.weight\": \"blocks.9.conv1.weight\",\n        \"encoder.down.3.block.0.conv2.bias\": \"blocks.9.conv2.bias\",\n        \"encoder.down.3.block.0.conv2.weight\": \"blocks.9.conv2.weight\",\n        \"encoder.down.3.block.0.norm1.bias\": \"blocks.9.norm1.bias\",\n        \"encoder.down.3.block.0.norm1.weight\": \"blocks.9.norm1.weight\",\n        \"encoder.down.3.block.0.norm2.bias\": \"blocks.9.norm2.bias\",\n        \"encoder.down.3.block.0.norm2.weight\": \"blocks.9.norm2.weight\",\n        \"encoder.down.3.block.1.conv1.bias\": \"blocks.10.conv1.bias\",\n        \"encoder.down.3.block.1.conv1.weight\": \"blocks.10.conv1.weight\",\n        \"encoder.down.3.block.1.conv2.bias\": \"blocks.10.conv2.bias\",\n        \"encoder.down.3.block.1.conv2.weight\": \"blocks.10.conv2.weight\",\n        \"encoder.down.3.block.1.norm1.bias\": \"blocks.10.norm1.bias\",\n        \"encoder.down.3.block.1.norm1.weight\": \"blocks.10.norm1.weight\",\n        \"encoder.down.3.block.1.norm2.bias\": \"blocks.10.norm2.bias\",\n        \"encoder.down.3.block.1.norm2.weight\": \"blocks.10.norm2.weight\",\n        
\"encoder.mid.attn_1.k.bias\": \"blocks.12.transformer_blocks.0.to_k.bias\",\n        \"encoder.mid.attn_1.k.weight\": \"blocks.12.transformer_blocks.0.to_k.weight\",\n        \"encoder.mid.attn_1.norm.bias\": \"blocks.12.norm.bias\",\n        \"encoder.mid.attn_1.norm.weight\": \"blocks.12.norm.weight\",\n        \"encoder.mid.attn_1.proj_out.bias\": \"blocks.12.transformer_blocks.0.to_out.bias\",\n        \"encoder.mid.attn_1.proj_out.weight\": \"blocks.12.transformer_blocks.0.to_out.weight\",\n        \"encoder.mid.attn_1.q.bias\": \"blocks.12.transformer_blocks.0.to_q.bias\",\n        \"encoder.mid.attn_1.q.weight\": \"blocks.12.transformer_blocks.0.to_q.weight\",\n        \"encoder.mid.attn_1.v.bias\": \"blocks.12.transformer_blocks.0.to_v.bias\",\n        \"encoder.mid.attn_1.v.weight\": \"blocks.12.transformer_blocks.0.to_v.weight\",\n        \"encoder.mid.block_1.conv1.bias\": \"blocks.11.conv1.bias\",\n        \"encoder.mid.block_1.conv1.weight\": \"blocks.11.conv1.weight\",\n        \"encoder.mid.block_1.conv2.bias\": \"blocks.11.conv2.bias\",\n        \"encoder.mid.block_1.conv2.weight\": \"blocks.11.conv2.weight\",\n        \"encoder.mid.block_1.norm1.bias\": \"blocks.11.norm1.bias\",\n        \"encoder.mid.block_1.norm1.weight\": \"blocks.11.norm1.weight\",\n        \"encoder.mid.block_1.norm2.bias\": \"blocks.11.norm2.bias\",\n        \"encoder.mid.block_1.norm2.weight\": \"blocks.11.norm2.weight\",\n        \"encoder.mid.block_2.conv1.bias\": \"blocks.13.conv1.bias\",\n        \"encoder.mid.block_2.conv1.weight\": \"blocks.13.conv1.weight\",\n        \"encoder.mid.block_2.conv2.bias\": \"blocks.13.conv2.bias\",\n        \"encoder.mid.block_2.conv2.weight\": \"blocks.13.conv2.weight\",\n        \"encoder.mid.block_2.norm1.bias\": \"blocks.13.norm1.bias\",\n        \"encoder.mid.block_2.norm1.weight\": \"blocks.13.norm1.weight\",\n        \"encoder.mid.block_2.norm2.bias\": \"blocks.13.norm2.bias\",\n        \"encoder.mid.block_2.norm2.weight\": 
\"blocks.13.norm2.weight\",\n        \"encoder.norm_out.bias\": \"conv_norm_out.bias\",\n        \"encoder.norm_out.weight\": \"conv_norm_out.weight\",\n    }\n    state_dict_ = {}\n    for name in state_dict:\n        if name in rename_dict:\n            param = state_dict[name]\n            state_dict_[rename_dict[name]] = param\n    return state_dict_\n\n\ndef FluxVAEDecoderStateDictConverter(state_dict):\n    rename_dict = {\n        \"decoder.conv_in.bias\": \"conv_in.bias\",\n        \"decoder.conv_in.weight\": \"conv_in.weight\",\n        \"decoder.conv_out.bias\": \"conv_out.bias\",\n        \"decoder.conv_out.weight\": \"conv_out.weight\",\n        \"decoder.mid.attn_1.k.bias\": \"blocks.1.transformer_blocks.0.to_k.bias\",\n        \"decoder.mid.attn_1.k.weight\": \"blocks.1.transformer_blocks.0.to_k.weight\",\n        \"decoder.mid.attn_1.norm.bias\": \"blocks.1.norm.bias\",\n        \"decoder.mid.attn_1.norm.weight\": \"blocks.1.norm.weight\",\n        \"decoder.mid.attn_1.proj_out.bias\": \"blocks.1.transformer_blocks.0.to_out.bias\",\n        \"decoder.mid.attn_1.proj_out.weight\": \"blocks.1.transformer_blocks.0.to_out.weight\",\n        \"decoder.mid.attn_1.q.bias\": \"blocks.1.transformer_blocks.0.to_q.bias\",\n        \"decoder.mid.attn_1.q.weight\": \"blocks.1.transformer_blocks.0.to_q.weight\",\n        \"decoder.mid.attn_1.v.bias\": \"blocks.1.transformer_blocks.0.to_v.bias\",\n        \"decoder.mid.attn_1.v.weight\": \"blocks.1.transformer_blocks.0.to_v.weight\",\n        \"decoder.mid.block_1.conv1.bias\": \"blocks.0.conv1.bias\",\n        \"decoder.mid.block_1.conv1.weight\": \"blocks.0.conv1.weight\",\n        \"decoder.mid.block_1.conv2.bias\": \"blocks.0.conv2.bias\",\n        \"decoder.mid.block_1.conv2.weight\": \"blocks.0.conv2.weight\",\n        \"decoder.mid.block_1.norm1.bias\": \"blocks.0.norm1.bias\",\n        \"decoder.mid.block_1.norm1.weight\": \"blocks.0.norm1.weight\",\n        \"decoder.mid.block_1.norm2.bias\": 
\"blocks.0.norm2.bias\",\n        \"decoder.mid.block_1.norm2.weight\": \"blocks.0.norm2.weight\",\n        \"decoder.mid.block_2.conv1.bias\": \"blocks.2.conv1.bias\",\n        \"decoder.mid.block_2.conv1.weight\": \"blocks.2.conv1.weight\",\n        \"decoder.mid.block_2.conv2.bias\": \"blocks.2.conv2.bias\",\n        \"decoder.mid.block_2.conv2.weight\": \"blocks.2.conv2.weight\",\n        \"decoder.mid.block_2.norm1.bias\": \"blocks.2.norm1.bias\",\n        \"decoder.mid.block_2.norm1.weight\": \"blocks.2.norm1.weight\",\n        \"decoder.mid.block_2.norm2.bias\": \"blocks.2.norm2.bias\",\n        \"decoder.mid.block_2.norm2.weight\": \"blocks.2.norm2.weight\",\n        \"decoder.norm_out.bias\": \"conv_norm_out.bias\",\n        \"decoder.norm_out.weight\": \"conv_norm_out.weight\",\n        \"decoder.up.0.block.0.conv1.bias\": \"blocks.15.conv1.bias\",\n        \"decoder.up.0.block.0.conv1.weight\": \"blocks.15.conv1.weight\",\n        \"decoder.up.0.block.0.conv2.bias\": \"blocks.15.conv2.bias\",\n        \"decoder.up.0.block.0.conv2.weight\": \"blocks.15.conv2.weight\",\n        \"decoder.up.0.block.0.nin_shortcut.bias\": \"blocks.15.conv_shortcut.bias\",\n        \"decoder.up.0.block.0.nin_shortcut.weight\": \"blocks.15.conv_shortcut.weight\",\n        \"decoder.up.0.block.0.norm1.bias\": \"blocks.15.norm1.bias\",\n        \"decoder.up.0.block.0.norm1.weight\": \"blocks.15.norm1.weight\",\n        \"decoder.up.0.block.0.norm2.bias\": \"blocks.15.norm2.bias\",\n        \"decoder.up.0.block.0.norm2.weight\": \"blocks.15.norm2.weight\",\n        \"decoder.up.0.block.1.conv1.bias\": \"blocks.16.conv1.bias\",\n        \"decoder.up.0.block.1.conv1.weight\": \"blocks.16.conv1.weight\",\n        \"decoder.up.0.block.1.conv2.bias\": \"blocks.16.conv2.bias\",\n        \"decoder.up.0.block.1.conv2.weight\": \"blocks.16.conv2.weight\",\n        \"decoder.up.0.block.1.norm1.bias\": \"blocks.16.norm1.bias\",\n        \"decoder.up.0.block.1.norm1.weight\": 
\"blocks.16.norm1.weight\",\n        \"decoder.up.0.block.1.norm2.bias\": \"blocks.16.norm2.bias\",\n        \"decoder.up.0.block.1.norm2.weight\": \"blocks.16.norm2.weight\",\n        \"decoder.up.0.block.2.conv1.bias\": \"blocks.17.conv1.bias\",\n        \"decoder.up.0.block.2.conv1.weight\": \"blocks.17.conv1.weight\",\n        \"decoder.up.0.block.2.conv2.bias\": \"blocks.17.conv2.bias\",\n        \"decoder.up.0.block.2.conv2.weight\": \"blocks.17.conv2.weight\",\n        \"decoder.up.0.block.2.norm1.bias\": \"blocks.17.norm1.bias\",\n        \"decoder.up.0.block.2.norm1.weight\": \"blocks.17.norm1.weight\",\n        \"decoder.up.0.block.2.norm2.bias\": \"blocks.17.norm2.bias\",\n        \"decoder.up.0.block.2.norm2.weight\": \"blocks.17.norm2.weight\",\n        \"decoder.up.1.block.0.conv1.bias\": \"blocks.11.conv1.bias\",\n        \"decoder.up.1.block.0.conv1.weight\": \"blocks.11.conv1.weight\",\n        \"decoder.up.1.block.0.conv2.bias\": \"blocks.11.conv2.bias\",\n        \"decoder.up.1.block.0.conv2.weight\": \"blocks.11.conv2.weight\",\n        \"decoder.up.1.block.0.nin_shortcut.bias\": \"blocks.11.conv_shortcut.bias\",\n        \"decoder.up.1.block.0.nin_shortcut.weight\": \"blocks.11.conv_shortcut.weight\",\n        \"decoder.up.1.block.0.norm1.bias\": \"blocks.11.norm1.bias\",\n        \"decoder.up.1.block.0.norm1.weight\": \"blocks.11.norm1.weight\",\n        \"decoder.up.1.block.0.norm2.bias\": \"blocks.11.norm2.bias\",\n        \"decoder.up.1.block.0.norm2.weight\": \"blocks.11.norm2.weight\",\n        \"decoder.up.1.block.1.conv1.bias\": \"blocks.12.conv1.bias\",\n        \"decoder.up.1.block.1.conv1.weight\": \"blocks.12.conv1.weight\",\n        \"decoder.up.1.block.1.conv2.bias\": \"blocks.12.conv2.bias\",\n        \"decoder.up.1.block.1.conv2.weight\": \"blocks.12.conv2.weight\",\n        \"decoder.up.1.block.1.norm1.bias\": \"blocks.12.norm1.bias\",\n        \"decoder.up.1.block.1.norm1.weight\": \"blocks.12.norm1.weight\",\n        
\"decoder.up.1.block.1.norm2.bias\": \"blocks.12.norm2.bias\",\n        \"decoder.up.1.block.1.norm2.weight\": \"blocks.12.norm2.weight\",\n        \"decoder.up.1.block.2.conv1.bias\": \"blocks.13.conv1.bias\",\n        \"decoder.up.1.block.2.conv1.weight\": \"blocks.13.conv1.weight\",\n        \"decoder.up.1.block.2.conv2.bias\": \"blocks.13.conv2.bias\",\n        \"decoder.up.1.block.2.conv2.weight\": \"blocks.13.conv2.weight\",\n        \"decoder.up.1.block.2.norm1.bias\": \"blocks.13.norm1.bias\",\n        \"decoder.up.1.block.2.norm1.weight\": \"blocks.13.norm1.weight\",\n        \"decoder.up.1.block.2.norm2.bias\": \"blocks.13.norm2.bias\",\n        \"decoder.up.1.block.2.norm2.weight\": \"blocks.13.norm2.weight\",\n        \"decoder.up.1.upsample.conv.bias\": \"blocks.14.conv.bias\",\n        \"decoder.up.1.upsample.conv.weight\": \"blocks.14.conv.weight\",\n        \"decoder.up.2.block.0.conv1.bias\": \"blocks.7.conv1.bias\",\n        \"decoder.up.2.block.0.conv1.weight\": \"blocks.7.conv1.weight\",\n        \"decoder.up.2.block.0.conv2.bias\": \"blocks.7.conv2.bias\",\n        \"decoder.up.2.block.0.conv2.weight\": \"blocks.7.conv2.weight\",\n        \"decoder.up.2.block.0.norm1.bias\": \"blocks.7.norm1.bias\",\n        \"decoder.up.2.block.0.norm1.weight\": \"blocks.7.norm1.weight\",\n        \"decoder.up.2.block.0.norm2.bias\": \"blocks.7.norm2.bias\",\n        \"decoder.up.2.block.0.norm2.weight\": \"blocks.7.norm2.weight\",\n        \"decoder.up.2.block.1.conv1.bias\": \"blocks.8.conv1.bias\",\n        \"decoder.up.2.block.1.conv1.weight\": \"blocks.8.conv1.weight\",\n        \"decoder.up.2.block.1.conv2.bias\": \"blocks.8.conv2.bias\",\n        \"decoder.up.2.block.1.conv2.weight\": \"blocks.8.conv2.weight\",\n        \"decoder.up.2.block.1.norm1.bias\": \"blocks.8.norm1.bias\",\n        \"decoder.up.2.block.1.norm1.weight\": \"blocks.8.norm1.weight\",\n        \"decoder.up.2.block.1.norm2.bias\": \"blocks.8.norm2.bias\",\n        
\"decoder.up.2.block.1.norm2.weight\": \"blocks.8.norm2.weight\",\n        \"decoder.up.2.block.2.conv1.bias\": \"blocks.9.conv1.bias\",\n        \"decoder.up.2.block.2.conv1.weight\": \"blocks.9.conv1.weight\",\n        \"decoder.up.2.block.2.conv2.bias\": \"blocks.9.conv2.bias\",\n        \"decoder.up.2.block.2.conv2.weight\": \"blocks.9.conv2.weight\",\n        \"decoder.up.2.block.2.norm1.bias\": \"blocks.9.norm1.bias\",\n        \"decoder.up.2.block.2.norm1.weight\": \"blocks.9.norm1.weight\",\n        \"decoder.up.2.block.2.norm2.bias\": \"blocks.9.norm2.bias\",\n        \"decoder.up.2.block.2.norm2.weight\": \"blocks.9.norm2.weight\",\n        \"decoder.up.2.upsample.conv.bias\": \"blocks.10.conv.bias\",\n        \"decoder.up.2.upsample.conv.weight\": \"blocks.10.conv.weight\",\n        \"decoder.up.3.block.0.conv1.bias\": \"blocks.3.conv1.bias\",\n        \"decoder.up.3.block.0.conv1.weight\": \"blocks.3.conv1.weight\",\n        \"decoder.up.3.block.0.conv2.bias\": \"blocks.3.conv2.bias\",\n        \"decoder.up.3.block.0.conv2.weight\": \"blocks.3.conv2.weight\",\n        \"decoder.up.3.block.0.norm1.bias\": \"blocks.3.norm1.bias\",\n        \"decoder.up.3.block.0.norm1.weight\": \"blocks.3.norm1.weight\",\n        \"decoder.up.3.block.0.norm2.bias\": \"blocks.3.norm2.bias\",\n        \"decoder.up.3.block.0.norm2.weight\": \"blocks.3.norm2.weight\",\n        \"decoder.up.3.block.1.conv1.bias\": \"blocks.4.conv1.bias\",\n        \"decoder.up.3.block.1.conv1.weight\": \"blocks.4.conv1.weight\",\n        \"decoder.up.3.block.1.conv2.bias\": \"blocks.4.conv2.bias\",\n        \"decoder.up.3.block.1.conv2.weight\": \"blocks.4.conv2.weight\",\n        \"decoder.up.3.block.1.norm1.bias\": \"blocks.4.norm1.bias\",\n        \"decoder.up.3.block.1.norm1.weight\": \"blocks.4.norm1.weight\",\n        \"decoder.up.3.block.1.norm2.bias\": \"blocks.4.norm2.bias\",\n        \"decoder.up.3.block.1.norm2.weight\": \"blocks.4.norm2.weight\",\n        
\"decoder.up.3.block.2.conv1.bias\": \"blocks.5.conv1.bias\",\n        \"decoder.up.3.block.2.conv1.weight\": \"blocks.5.conv1.weight\",\n        \"decoder.up.3.block.2.conv2.bias\": \"blocks.5.conv2.bias\",\n        \"decoder.up.3.block.2.conv2.weight\": \"blocks.5.conv2.weight\",\n        \"decoder.up.3.block.2.norm1.bias\": \"blocks.5.norm1.bias\",\n        \"decoder.up.3.block.2.norm1.weight\": \"blocks.5.norm1.weight\",\n        \"decoder.up.3.block.2.norm2.bias\": \"blocks.5.norm2.bias\",\n        \"decoder.up.3.block.2.norm2.weight\": \"blocks.5.norm2.weight\",\n        \"decoder.up.3.upsample.conv.bias\": \"blocks.6.conv.bias\",\n        \"decoder.up.3.upsample.conv.weight\": \"blocks.6.conv.weight\",\n    }\n    state_dict_ = {}\n    for name in state_dict:\n        if name in rename_dict:\n            param = state_dict[name]\n            state_dict_[rename_dict[name]] = param\n    return state_dict_\n\n\ndef FluxVAEEncoderStateDictConverterDiffusers(state_dict):\n    # architecture\n    block_types = [\n        'ResnetBlock', 'ResnetBlock', 'DownSampler',\n        'ResnetBlock', 'ResnetBlock', 'DownSampler',\n        'ResnetBlock', 'ResnetBlock', 'DownSampler',\n        'ResnetBlock', 'ResnetBlock',\n        'ResnetBlock', 'VAEAttentionBlock', 'ResnetBlock'\n    ]\n\n    # Rename each parameter\n    local_rename_dict = {\n        \"quant_conv\": \"quant_conv\",\n        \"encoder.conv_in\": \"conv_in\",\n        \"encoder.mid_block.attentions.0.group_norm\": \"blocks.12.norm\",\n        \"encoder.mid_block.attentions.0.to_q\": \"blocks.12.transformer_blocks.0.to_q\",\n        \"encoder.mid_block.attentions.0.to_k\": \"blocks.12.transformer_blocks.0.to_k\",\n        \"encoder.mid_block.attentions.0.to_v\": \"blocks.12.transformer_blocks.0.to_v\",\n        \"encoder.mid_block.attentions.0.to_out.0\": \"blocks.12.transformer_blocks.0.to_out\",\n        \"encoder.mid_block.resnets.0.norm1\": \"blocks.11.norm1\",\n        
\"encoder.mid_block.resnets.0.conv1\": \"blocks.11.conv1\",\n        \"encoder.mid_block.resnets.0.norm2\": \"blocks.11.norm2\",\n        \"encoder.mid_block.resnets.0.conv2\": \"blocks.11.conv2\",\n        \"encoder.mid_block.resnets.1.norm1\": \"blocks.13.norm1\",\n        \"encoder.mid_block.resnets.1.conv1\": \"blocks.13.conv1\",\n        \"encoder.mid_block.resnets.1.norm2\": \"blocks.13.norm2\",\n        \"encoder.mid_block.resnets.1.conv2\": \"blocks.13.conv2\",\n        \"encoder.conv_norm_out\": \"conv_norm_out\",\n        \"encoder.conv_out\": \"conv_out\",\n    }\n    name_list = sorted([name for name in state_dict])\n    rename_dict = {}\n    block_id = {\"ResnetBlock\": -1, \"DownSampler\": -1, \"UpSampler\": -1}\n    last_block_type_with_id = {\"ResnetBlock\": \"\", \"DownSampler\": \"\", \"UpSampler\": \"\"}\n    for name in name_list:\n        names = name.split(\".\")\n        name_prefix = \".\".join(names[:-1])\n        if name_prefix in local_rename_dict:\n            rename_dict[name] = local_rename_dict[name_prefix] + \".\" + names[-1]\n        elif name.startswith(\"encoder.down_blocks\"):\n            block_type = {\"resnets\": \"ResnetBlock\", \"downsamplers\": \"DownSampler\", \"upsamplers\": \"UpSampler\"}[names[3]]\n            block_type_with_id = \".\".join(names[:5])\n            if block_type_with_id != last_block_type_with_id[block_type]:\n                block_id[block_type] += 1\n            last_block_type_with_id[block_type] = block_type_with_id\n            while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type:\n                block_id[block_type] += 1\n            block_type_with_id = \".\".join(names[:5])\n            names = [\"blocks\", str(block_id[block_type])] + names[5:]\n            rename_dict[name] = \".\".join(names)\n\n    # Convert state_dict\n    state_dict_ = {}\n    for name in state_dict:\n        if name in rename_dict:\n            state_dict_[rename_dict[name]] = 
state_dict[name]\n    return state_dict_\n\n\ndef FluxVAEDecoderStateDictConverterDiffusers(state_dict):\n    # architecture\n        block_types = [\n            'ResnetBlock', 'VAEAttentionBlock', 'ResnetBlock',\n            'ResnetBlock', 'ResnetBlock', 'ResnetBlock', 'UpSampler',\n            'ResnetBlock', 'ResnetBlock', 'ResnetBlock', 'UpSampler',\n            'ResnetBlock', 'ResnetBlock', 'ResnetBlock', 'UpSampler',\n            'ResnetBlock', 'ResnetBlock', 'ResnetBlock'\n        ]\n\n        # Rename each parameter\n        local_rename_dict = {\n            \"post_quant_conv\": \"post_quant_conv\",\n            \"decoder.conv_in\": \"conv_in\",\n            \"decoder.mid_block.attentions.0.group_norm\": \"blocks.1.norm\",\n            \"decoder.mid_block.attentions.0.to_q\": \"blocks.1.transformer_blocks.0.to_q\",\n            \"decoder.mid_block.attentions.0.to_k\": \"blocks.1.transformer_blocks.0.to_k\",\n            \"decoder.mid_block.attentions.0.to_v\": \"blocks.1.transformer_blocks.0.to_v\",\n            \"decoder.mid_block.attentions.0.to_out.0\": \"blocks.1.transformer_blocks.0.to_out\",\n            \"decoder.mid_block.resnets.0.norm1\": \"blocks.0.norm1\",\n            \"decoder.mid_block.resnets.0.conv1\": \"blocks.0.conv1\",\n            \"decoder.mid_block.resnets.0.norm2\": \"blocks.0.norm2\",\n            \"decoder.mid_block.resnets.0.conv2\": \"blocks.0.conv2\",\n            \"decoder.mid_block.resnets.1.norm1\": \"blocks.2.norm1\",\n            \"decoder.mid_block.resnets.1.conv1\": \"blocks.2.conv1\",\n            \"decoder.mid_block.resnets.1.norm2\": \"blocks.2.norm2\",\n            \"decoder.mid_block.resnets.1.conv2\": \"blocks.2.conv2\",\n            \"decoder.conv_norm_out\": \"conv_norm_out\",\n            \"decoder.conv_out\": \"conv_out\",\n        }\n        name_list = sorted([name for name in state_dict])\n        rename_dict = {}\n        block_id = {\"ResnetBlock\": 2, \"DownSampler\": 2, \"UpSampler\": 2}\n        
last_block_type_with_id = {\"ResnetBlock\": \"\", \"DownSampler\": \"\", \"UpSampler\": \"\"}\n        for name in name_list:\n            names = name.split(\".\")\n            name_prefix = \".\".join(names[:-1])\n            if name_prefix in local_rename_dict:\n                rename_dict[name] = local_rename_dict[name_prefix] + \".\" + names[-1]\n            elif name.startswith(\"decoder.up_blocks\"):\n                block_type = {\"resnets\": \"ResnetBlock\", \"downsamplers\": \"DownSampler\", \"upsamplers\": \"UpSampler\"}[names[3]]\n                block_type_with_id = \".\".join(names[:5])\n                if block_type_with_id != last_block_type_with_id[block_type]:\n                    block_id[block_type] += 1\n                last_block_type_with_id[block_type] = block_type_with_id\n                while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type:\n                    block_id[block_type] += 1\n                block_type_with_id = \".\".join(names[:5])\n                names = [\"blocks\", str(block_id[block_type])] + names[5:]\n                rename_dict[name] = \".\".join(names)\n\n        # Convert state_dict\n        state_dict_ = {}\n        for name in state_dict:\n            if name in rename_dict:\n                state_dict_[rename_dict[name]] = state_dict[name]\n        return state_dict_"
  },
  {
    "path": "diffsynth/utils/state_dict_converters/ltx2_audio_vae.py",
    "content": "def LTX2AudioEncoderStateDictConverter(state_dict):\n    # Not used\n    state_dict_ = {}\n    for name in state_dict:\n        if name.startswith(\"audio_vae.encoder.\"):\n            new_name = name.replace(\"audio_vae.encoder.\", \"\")\n            state_dict_[new_name] = state_dict[name]\n        elif name.startswith(\"audio_vae.per_channel_statistics.\"):\n            new_name = name.replace(\"audio_vae.per_channel_statistics.\", \"per_channel_statistics.\")\n            state_dict_[new_name] = state_dict[name]\n    return state_dict_\n\n\ndef LTX2AudioDecoderStateDictConverter(state_dict):\n    state_dict_ = {}\n    for name in state_dict:\n        if name.startswith(\"audio_vae.decoder.\"):\n            new_name = name.replace(\"audio_vae.decoder.\", \"\")\n            state_dict_[new_name] = state_dict[name]\n        elif name.startswith(\"audio_vae.per_channel_statistics.\"):\n            new_name = name.replace(\"audio_vae.per_channel_statistics.\", \"per_channel_statistics.\")\n            state_dict_[new_name] = state_dict[name]\n    return state_dict_\n\n\ndef LTX2VocoderStateDictConverter(state_dict):\n    state_dict_ = {}\n    for name in state_dict:\n        if name.startswith(\"vocoder.\"):\n            new_name = name[len(\"vocoder.\"):]\n            state_dict_[new_name] = state_dict[name]\n    return state_dict_\n"
  },
  {
    "path": "diffsynth/utils/state_dict_converters/ltx2_dit.py",
    "content": "def LTXModelStateDictConverter(state_dict):\n    state_dict_ = {}\n    for name in state_dict:\n        if name.startswith(\"model.diffusion_model.\"):\n            new_name = name.replace(\"model.diffusion_model.\", \"\")\n            if new_name.startswith(\"audio_embeddings_connector.\") or new_name.startswith(\"video_embeddings_connector.\"):\n                continue\n            state_dict_[new_name] = state_dict[name]\n    return state_dict_\n"
  },
  {
    "path": "diffsynth/utils/state_dict_converters/ltx2_text_encoder.py",
    "content": "def LTX2TextEncoderStateDictConverter(state_dict):\n    state_dict_ = {}\n    for key in state_dict:\n        if key.startswith(\"language_model.model.\"):\n            new_key = key.replace(\"language_model.model.\", \"model.language_model.\")\n        elif key.startswith(\"vision_tower.\"):\n            new_key = key.replace(\"vision_tower.\", \"model.vision_tower.\")\n        elif key.startswith(\"multi_modal_projector.\"):\n            new_key = key.replace(\"multi_modal_projector.\", \"model.multi_modal_projector.\")\n        elif key.startswith(\"language_model.lm_head.\"):\n            new_key = key.replace(\"language_model.lm_head.\", \"lm_head.\")\n        else:\n            continue\n        state_dict_[new_key] = state_dict[key]\n    state_dict_[\"lm_head.weight\"] = state_dict_.get(\"model.language_model.embed_tokens.weight\")\n    return state_dict_\n\n\ndef LTX2TextEncoderPostModulesStateDictConverter(state_dict):\n    state_dict_ = {}\n    for key in state_dict:\n        if key.startswith(\"text_embedding_projection.\"):\n            new_key = key.replace(\"text_embedding_projection.\", \"feature_extractor_linear.\")\n        elif key.startswith(\"model.diffusion_model.video_embeddings_connector.\"):\n            new_key = key.replace(\"model.diffusion_model.video_embeddings_connector.\", \"embeddings_connector.\")\n        elif key.startswith(\"model.diffusion_model.audio_embeddings_connector.\"):\n            new_key = key.replace(\"model.diffusion_model.audio_embeddings_connector.\", \"audio_embeddings_connector.\")\n        else:\n            continue\n        state_dict_[new_key] = state_dict[key]\n    return state_dict_\n"
  },
  {
    "path": "diffsynth/utils/state_dict_converters/ltx2_video_vae.py",
    "content": "def LTX2VideoEncoderStateDictConverter(state_dict):\n    state_dict_ = {}\n    for name in state_dict:\n        if name.startswith(\"vae.encoder.\"):\n            new_name = name.replace(\"vae.encoder.\", \"\")\n            state_dict_[new_name] = state_dict[name]\n        elif name.startswith(\"vae.per_channel_statistics.\"):\n            new_name = name.replace(\"vae.per_channel_statistics.\", \"per_channel_statistics.\")\n            if new_name not in [\"per_channel_statistics.channel\", \"per_channel_statistics.mean-of-stds\", \"per_channel_statistics.mean-of-stds_over_std-of-means\"]:\n                state_dict_[new_name] = state_dict[name]\n    return state_dict_\n\n\ndef LTX2VideoDecoderStateDictConverter(state_dict):\n    state_dict_ = {}\n    for name in state_dict:\n        if name.startswith(\"vae.decoder.\"):\n            new_name = name.replace(\"vae.decoder.\", \"\")\n            state_dict_[new_name] = state_dict[name]\n        elif name.startswith(\"vae.per_channel_statistics.\"):\n            new_name = name.replace(\"vae.per_channel_statistics.\", \"per_channel_statistics.\")\n            if new_name not in [\"per_channel_statistics.channel\", \"per_channel_statistics.mean-of-stds\", \"per_channel_statistics.mean-of-stds_over_std-of-means\"]:\n                state_dict_[new_name] = state_dict[name]\n    return state_dict_\n"
  },
  {
    "path": "diffsynth/utils/state_dict_converters/nexus_gen.py",
    "content": "def NexusGenAutoregressiveModelStateDictConverter(state_dict):\n    new_state_dict = {}\n    for key in state_dict:\n        value = state_dict[key]\n        new_state_dict[\"model.\" + key] = value\n    return new_state_dict"
  },
  {
    "path": "diffsynth/utils/state_dict_converters/nexus_gen_projector.py",
    "content": "def NexusGenMergerStateDictConverter(state_dict):\n    merger_state_dict = {}\n    for key in state_dict:\n        if key.startswith('embedding_merger.'):\n            value = state_dict[key]\n            new_key = key.replace(\"embedding_merger.\", \"\")\n            merger_state_dict[new_key] = value\n    return merger_state_dict\n\ndef NexusGenAdapterStateDictConverter(state_dict):\n    adapter_state_dict = {}\n    for key in state_dict:\n        if key.startswith('adapter.'):\n            adapter_state_dict[key] = state_dict[key]\n    return adapter_state_dict"
  },
  {
    "path": "diffsynth/utils/state_dict_converters/qwen_image_text_encoder.py",
    "content": "def QwenImageTextEncoderStateDictConverter(state_dict):\n    state_dict_ = {}\n    for k in state_dict:\n        v = state_dict[k]\n        if k.startswith(\"visual.\"):\n            k = \"model.\" + k\n        elif k.startswith(\"model.\"):\n            k = k.replace(\"model.\", \"model.language_model.\")\n        state_dict_[k] = v\n    return state_dict_\n"
  },
  {
    "path": "diffsynth/utils/state_dict_converters/step1x_connector.py",
    "content": "def Qwen2ConnectorStateDictConverter(state_dict):\n    state_dict_ = {}\n    for name in state_dict:\n        if name.startswith(\"connector.\"):\n            name_ = name[len(\"connector.\"):]\n            state_dict_[name_] = state_dict[name]\n    return state_dict_"
  },
  {
    "path": "diffsynth/utils/state_dict_converters/wan_video_animate_adapter.py",
    "content": "def WanAnimateAdapterStateDictConverter(state_dict):\n    state_dict_ = {}\n    for name in state_dict:\n        if name.startswith(\"pose_patch_embedding.\") or name.startswith(\"face_adapter\") or name.startswith(\"face_encoder\") or name.startswith(\"motion_encoder\"):\n            state_dict_[name] = state_dict[name]\n    return state_dict_"
  },
  {
    "path": "diffsynth/utils/state_dict_converters/wan_video_dit.py",
    "content": "def WanVideoDiTFromDiffusers(state_dict):\n    rename_dict = {\n        \"blocks.0.attn1.norm_k.weight\": \"blocks.0.self_attn.norm_k.weight\",\n        \"blocks.0.attn1.norm_q.weight\": \"blocks.0.self_attn.norm_q.weight\",\n        \"blocks.0.attn1.to_k.bias\": \"blocks.0.self_attn.k.bias\",\n        \"blocks.0.attn1.to_k.weight\": \"blocks.0.self_attn.k.weight\",\n        \"blocks.0.attn1.to_out.0.bias\": \"blocks.0.self_attn.o.bias\",\n        \"blocks.0.attn1.to_out.0.weight\": \"blocks.0.self_attn.o.weight\",\n        \"blocks.0.attn1.to_q.bias\": \"blocks.0.self_attn.q.bias\",\n        \"blocks.0.attn1.to_q.weight\": \"blocks.0.self_attn.q.weight\",\n        \"blocks.0.attn1.to_v.bias\": \"blocks.0.self_attn.v.bias\",\n        \"blocks.0.attn1.to_v.weight\": \"blocks.0.self_attn.v.weight\",\n        \"blocks.0.attn2.norm_k.weight\": \"blocks.0.cross_attn.norm_k.weight\",\n        \"blocks.0.attn2.norm_q.weight\": \"blocks.0.cross_attn.norm_q.weight\",\n        \"blocks.0.attn2.to_k.bias\": \"blocks.0.cross_attn.k.bias\",\n        \"blocks.0.attn2.to_k.weight\": \"blocks.0.cross_attn.k.weight\",\n        \"blocks.0.attn2.to_out.0.bias\": \"blocks.0.cross_attn.o.bias\",\n        \"blocks.0.attn2.to_out.0.weight\": \"blocks.0.cross_attn.o.weight\",\n        \"blocks.0.attn2.to_q.bias\": \"blocks.0.cross_attn.q.bias\",\n        \"blocks.0.attn2.to_q.weight\": \"blocks.0.cross_attn.q.weight\",\n        \"blocks.0.attn2.to_v.bias\": \"blocks.0.cross_attn.v.bias\",\n        \"blocks.0.attn2.to_v.weight\": \"blocks.0.cross_attn.v.weight\",\n        \"blocks.0.attn2.add_k_proj.bias\":\"blocks.0.cross_attn.k_img.bias\",\n        \"blocks.0.attn2.add_k_proj.weight\":\"blocks.0.cross_attn.k_img.weight\",\n        \"blocks.0.attn2.add_v_proj.bias\":\"blocks.0.cross_attn.v_img.bias\",\n        \"blocks.0.attn2.add_v_proj.weight\":\"blocks.0.cross_attn.v_img.weight\",\n        
\"blocks.0.attn2.norm_added_k.weight\":\"blocks.0.cross_attn.norm_k_img.weight\",\n        \"blocks.0.ffn.net.0.proj.bias\": \"blocks.0.ffn.0.bias\",\n        \"blocks.0.ffn.net.0.proj.weight\": \"blocks.0.ffn.0.weight\",\n        \"blocks.0.ffn.net.2.bias\": \"blocks.0.ffn.2.bias\",\n        \"blocks.0.ffn.net.2.weight\": \"blocks.0.ffn.2.weight\",\n        \"blocks.0.norm2.bias\": \"blocks.0.norm3.bias\",\n        \"blocks.0.norm2.weight\": \"blocks.0.norm3.weight\",\n        \"blocks.0.scale_shift_table\": \"blocks.0.modulation\",\n        \"condition_embedder.text_embedder.linear_1.bias\": \"text_embedding.0.bias\",\n        \"condition_embedder.text_embedder.linear_1.weight\": \"text_embedding.0.weight\",\n        \"condition_embedder.text_embedder.linear_2.bias\": \"text_embedding.2.bias\",\n        \"condition_embedder.text_embedder.linear_2.weight\": \"text_embedding.2.weight\",\n        \"condition_embedder.time_embedder.linear_1.bias\": \"time_embedding.0.bias\",\n        \"condition_embedder.time_embedder.linear_1.weight\": \"time_embedding.0.weight\",\n        \"condition_embedder.time_embedder.linear_2.bias\": \"time_embedding.2.bias\",\n        \"condition_embedder.time_embedder.linear_2.weight\": \"time_embedding.2.weight\",\n        \"condition_embedder.time_proj.bias\": \"time_projection.1.bias\",\n        \"condition_embedder.time_proj.weight\": \"time_projection.1.weight\",\n        \"condition_embedder.image_embedder.ff.net.0.proj.bias\":\"img_emb.proj.1.bias\",\n        \"condition_embedder.image_embedder.ff.net.0.proj.weight\":\"img_emb.proj.1.weight\",\n        \"condition_embedder.image_embedder.ff.net.2.bias\":\"img_emb.proj.3.bias\",\n        \"condition_embedder.image_embedder.ff.net.2.weight\":\"img_emb.proj.3.weight\",\n        \"condition_embedder.image_embedder.norm1.bias\":\"img_emb.proj.0.bias\",\n        \"condition_embedder.image_embedder.norm1.weight\":\"img_emb.proj.0.weight\",\n        
\"condition_embedder.image_embedder.norm2.bias\":\"img_emb.proj.4.bias\",\n        \"condition_embedder.image_embedder.norm2.weight\":\"img_emb.proj.4.weight\",\n        \"patch_embedding.bias\": \"patch_embedding.bias\",\n        \"patch_embedding.weight\": \"patch_embedding.weight\",\n        \"scale_shift_table\": \"head.modulation\",\n        \"proj_out.bias\": \"head.head.bias\",\n        \"proj_out.weight\": \"head.head.weight\",\n    }\n    state_dict_ = {}\n    for name in state_dict:\n        if name in rename_dict:\n            state_dict_[rename_dict[name]] = state_dict[name]\n        else:\n            name_ = \".\".join(name.split(\".\")[:1] + [\"0\"] + name.split(\".\")[2:])\n            if name_ in rename_dict:\n                name_ = rename_dict[name_]\n                name_ = \".\".join(name_.split(\".\")[:1] + [name.split(\".\")[1]] + name_.split(\".\")[2:])\n                state_dict_[name_] = state_dict[name]\n    return state_dict_\n\n\ndef WanVideoDiTStateDictConverter(state_dict):\n    state_dict_ = {}\n    for name in state_dict:\n        if name.startswith(\"vace\"):\n            continue\n        if name.split(\".\")[0] in [\"pose_patch_embedding\", \"face_adapter\", \"face_encoder\", \"motion_encoder\"]:\n            continue\n        name_ = name\n        if name_.startswith(\"model.\"):\n            name_ = name_[len(\"model.\"):]\n        state_dict_[name_] = state_dict[name]\n    return state_dict_\n"
  },
  {
    "path": "diffsynth/utils/state_dict_converters/wan_video_image_encoder.py",
    "content": "def WanImageEncoderStateDictConverter(state_dict):\n    state_dict_ = {}\n    for name in state_dict:\n        if name.startswith(\"textual.\"):\n            continue\n        name_ = \"model.\" + name\n        state_dict_[name_] = state_dict[name]\n    return state_dict_"
  },
  {
    "path": "diffsynth/utils/state_dict_converters/wan_video_mot.py",
    "content": "def WanVideoMotStateDictConverter(state_dict):\n    rename_dict = {\n        \"blocks.0.attn1.norm_k.weight\": \"blocks.0.self_attn.norm_k.weight\",\n        \"blocks.0.attn1.norm_q.weight\": \"blocks.0.self_attn.norm_q.weight\",\n        \"blocks.0.attn1.to_k.bias\": \"blocks.0.self_attn.k.bias\",\n        \"blocks.0.attn1.to_k.weight\": \"blocks.0.self_attn.k.weight\",\n        \"blocks.0.attn1.to_out.0.bias\": \"blocks.0.self_attn.o.bias\",\n        \"blocks.0.attn1.to_out.0.weight\": \"blocks.0.self_attn.o.weight\",\n        \"blocks.0.attn1.to_q.bias\": \"blocks.0.self_attn.q.bias\",\n        \"blocks.0.attn1.to_q.weight\": \"blocks.0.self_attn.q.weight\",\n        \"blocks.0.attn1.to_v.bias\": \"blocks.0.self_attn.v.bias\",\n        \"blocks.0.attn1.to_v.weight\": \"blocks.0.self_attn.v.weight\",\n        \"blocks.0.attn2.norm_k.weight\": \"blocks.0.cross_attn.norm_k.weight\",\n        \"blocks.0.attn2.norm_q.weight\": \"blocks.0.cross_attn.norm_q.weight\",\n        \"blocks.0.attn2.to_k.bias\": \"blocks.0.cross_attn.k.bias\",\n        \"blocks.0.attn2.to_k.weight\": \"blocks.0.cross_attn.k.weight\",\n        \"blocks.0.attn2.to_out.0.bias\": \"blocks.0.cross_attn.o.bias\",\n        \"blocks.0.attn2.to_out.0.weight\": \"blocks.0.cross_attn.o.weight\",\n        \"blocks.0.attn2.to_q.bias\": \"blocks.0.cross_attn.q.bias\",\n        \"blocks.0.attn2.to_q.weight\": \"blocks.0.cross_attn.q.weight\",\n        \"blocks.0.attn2.to_v.bias\": \"blocks.0.cross_attn.v.bias\",\n        \"blocks.0.attn2.to_v.weight\": \"blocks.0.cross_attn.v.weight\",\n        \"blocks.0.attn2.add_k_proj.bias\":\"blocks.0.cross_attn.k_img.bias\",\n        \"blocks.0.attn2.add_k_proj.weight\":\"blocks.0.cross_attn.k_img.weight\",\n        \"blocks.0.attn2.add_v_proj.bias\":\"blocks.0.cross_attn.v_img.bias\",\n        \"blocks.0.attn2.add_v_proj.weight\":\"blocks.0.cross_attn.v_img.weight\",\n        
\"blocks.0.attn2.norm_added_k.weight\":\"blocks.0.cross_attn.norm_k_img.weight\",\n        \"blocks.0.ffn.net.0.proj.bias\": \"blocks.0.ffn.0.bias\",\n        \"blocks.0.ffn.net.0.proj.weight\": \"blocks.0.ffn.0.weight\",\n        \"blocks.0.ffn.net.2.bias\": \"blocks.0.ffn.2.bias\",\n        \"blocks.0.ffn.net.2.weight\": \"blocks.0.ffn.2.weight\",\n        \"blocks.0.norm2.bias\": \"blocks.0.norm3.bias\",\n        \"blocks.0.norm2.weight\": \"blocks.0.norm3.weight\",\n        \"blocks.0.scale_shift_table\": \"blocks.0.modulation\",\n        \"condition_embedder.text_embedder.linear_1.bias\": \"text_embedding.0.bias\",\n        \"condition_embedder.text_embedder.linear_1.weight\": \"text_embedding.0.weight\",\n        \"condition_embedder.text_embedder.linear_2.bias\": \"text_embedding.2.bias\",\n        \"condition_embedder.text_embedder.linear_2.weight\": \"text_embedding.2.weight\",\n        \"condition_embedder.time_embedder.linear_1.bias\": \"time_embedding.0.bias\",\n        \"condition_embedder.time_embedder.linear_1.weight\": \"time_embedding.0.weight\",\n        \"condition_embedder.time_embedder.linear_2.bias\": \"time_embedding.2.bias\",\n        \"condition_embedder.time_embedder.linear_2.weight\": \"time_embedding.2.weight\",\n        \"condition_embedder.time_proj.bias\": \"time_projection.1.bias\",\n        \"condition_embedder.time_proj.weight\": \"time_projection.1.weight\",\n        \"condition_embedder.image_embedder.ff.net.0.proj.bias\":\"img_emb.proj.1.bias\",\n        \"condition_embedder.image_embedder.ff.net.0.proj.weight\":\"img_emb.proj.1.weight\",\n        \"condition_embedder.image_embedder.ff.net.2.bias\":\"img_emb.proj.3.bias\",\n        \"condition_embedder.image_embedder.ff.net.2.weight\":\"img_emb.proj.3.weight\",\n        \"condition_embedder.image_embedder.norm1.bias\":\"img_emb.proj.0.bias\",\n        \"condition_embedder.image_embedder.norm1.weight\":\"img_emb.proj.0.weight\",\n        
\"condition_embedder.image_embedder.norm2.bias\":\"img_emb.proj.4.bias\",\n        \"condition_embedder.image_embedder.norm2.weight\":\"img_emb.proj.4.weight\",\n        \"patch_embedding.bias\": \"patch_embedding.bias\",\n        \"patch_embedding.weight\": \"patch_embedding.weight\",\n        \"scale_shift_table\": \"head.modulation\",\n        \"proj_out.bias\": \"head.head.bias\",\n        \"proj_out.weight\": \"head.head.weight\",\n    }\n    mot_layers = (0, 4, 8, 12, 16, 20, 24, 28, 32, 36)\n    mot_layers_mapping = {i:n for n, i in enumerate(mot_layers)}\n    state_dict_ = {}\n    for name in state_dict:\n        if \"_mot_ref\" not in name:\n            continue\n        param = state_dict[name]\n        name = name.replace(\"_mot_ref\", \"\")\n        if name in rename_dict:\n            state_dict_[rename_dict[name]] = param\n        else:\n            if name.split(\".\")[1].isdigit():\n                block_id = int(name.split(\".\")[1])\n                name = name.replace(str(block_id), str(mot_layers_mapping[block_id]))\n            name_ = \".\".join(name.split(\".\")[:1] + [\"0\"] + name.split(\".\")[2:])\n            if name_ in rename_dict:\n                name_ = rename_dict[name_]\n                name_ = \".\".join(name_.split(\".\")[:1] + [name.split(\".\")[1]] + name_.split(\".\")[2:])\n                state_dict_[name_] = param\n    return state_dict_\n"
  },
  {
    "path": "diffsynth/utils/state_dict_converters/wan_video_vace.py",
    "content": "def VaceWanModelDictConverter(state_dict):\n    state_dict_ = {name: state_dict[name] for name in state_dict if name.startswith(\"vace\")}\n    return state_dict_\n"
  },
  {
    "path": "diffsynth/utils/state_dict_converters/wan_video_vae.py",
    "content": "def WanVideoVAEStateDictConverter(state_dict):\n    state_dict_ = {}\n    if 'model_state' in state_dict:\n        state_dict = state_dict['model_state']\n    for name in state_dict:\n        state_dict_['model.' + name] = state_dict[name]\n    return state_dict_"
  },
  {
    "path": "diffsynth/utils/state_dict_converters/wans2v_audio_encoder.py",
    "content": "def WanS2VAudioEncoderStateDictConverter(state_dict):\n    rename_dict = {\n        \"model.wav2vec2.encoder.pos_conv_embed.conv.weight_g\": \"model.wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0\",\n        \"model.wav2vec2.encoder.pos_conv_embed.conv.weight_v\": \"model.wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1\",\n    }\n    state_dict_ = {}\n    for name in state_dict:\n        name_ = \"model.\" + name\n        if name_ in rename_dict:\n            name_ = rename_dict[name_]\n        state_dict_[name_] = state_dict[name]\n    return state_dict_\n"
  },
  {
    "path": "diffsynth/utils/state_dict_converters/z_image_text_encoder.py",
    "content": "def ZImageTextEncoderStateDictConverter(state_dict):\n    state_dict_ = {}\n    for name in state_dict:\n        if name != \"lm_head.weight\":\n            state_dict_[name] = state_dict[name]\n    return state_dict_"
  },
  {
    "path": "diffsynth/utils/xfuser/__init__.py",
    "content": "from .xdit_context_parallel import usp_attn_forward, usp_dit_forward, get_sequence_parallel_world_size, initialize_usp, get_current_chunk, gather_all_chunks\n"
  },
  {
    "path": "diffsynth/utils/xfuser/xdit_context_parallel.py",
    "content": "import torch\nfrom typing import Optional\nfrom einops import rearrange\nfrom yunchang.kernels import AttnType\nfrom xfuser.core.distributed import (get_sequence_parallel_rank,\n                                     get_sequence_parallel_world_size,\n                                     get_sp_group)\nfrom xfuser.core.long_ctx_attention import xFuserLongContextAttention\n\nfrom ... import IS_NPU_AVAILABLE\nfrom ...core.device import parse_nccl_backend, parse_device_type\nfrom ...core.gradient import gradient_checkpoint_forward\n\n\ndef initialize_usp(device_type):\n    import torch.distributed as dist\n    from xfuser.core.distributed import initialize_model_parallel, init_distributed_environment\n    dist.init_process_group(backend=parse_nccl_backend(device_type), init_method=\"env://\")\n    init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size())\n    initialize_model_parallel(\n        sequence_parallel_degree=dist.get_world_size(),\n        ring_degree=1,\n        ulysses_degree=dist.get_world_size(),\n    )\n    getattr(torch, device_type).set_device(dist.get_rank())\n\n\ndef sinusoidal_embedding_1d(dim, position):\n    sinusoid = torch.outer(position.type(torch.float64), torch.pow(\n        10000, -torch.arange(dim//2, dtype=torch.float64, device=position.device).div(dim//2)))\n    x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)\n    return x.to(position.dtype)\n\ndef pad_freqs(original_tensor, target_len):\n    seq_len, s1, s2 = original_tensor.shape\n    pad_size = target_len - seq_len\n    original_tensor_device = original_tensor.device\n    if original_tensor.device == \"npu\":\n        original_tensor = original_tensor.cpu()\n    padding_tensor = torch.ones(\n        pad_size,\n        s1,\n        s2,\n        dtype=original_tensor.dtype,\n        device=original_tensor.device)\n    padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0).to(device=original_tensor_device)\n    
return padded_tensor\n    \ndef rope_apply(x, freqs, num_heads):\n    x = rearrange(x, \"b s (n d) -> b s n d\", n=num_heads)\n    s_per_rank = x.shape[1]\n\n    x_out = torch.view_as_complex(x.to(torch.float64).reshape(\n        x.shape[0], x.shape[1], x.shape[2], -1, 2))\n\n    sp_size = get_sequence_parallel_world_size()\n    sp_rank = get_sequence_parallel_rank()\n    freqs = pad_freqs(freqs, s_per_rank * sp_size)\n    freqs_rank = freqs[(sp_rank * s_per_rank):((sp_rank + 1) * s_per_rank), :, :]\n    freqs_rank = freqs_rank.to(torch.complex64) if freqs_rank.device.type == \"npu\" else freqs_rank\n    x_out = torch.view_as_real(x_out * freqs_rank).flatten(2)\n    return x_out.to(x.dtype)\n\ndef usp_dit_forward(self,\n            x: torch.Tensor,\n            timestep: torch.Tensor,\n            context: torch.Tensor,\n            clip_feature: Optional[torch.Tensor] = None,\n            y: Optional[torch.Tensor] = None,\n            use_gradient_checkpointing: bool = False,\n            use_gradient_checkpointing_offload: bool = False,\n            **kwargs,\n            ):\n    t = self.time_embedding(\n        sinusoidal_embedding_1d(self.freq_dim, timestep))\n    t_mod = self.time_projection(t).unflatten(1, (6, self.dim))\n    context = self.text_embedding(context)\n    \n    if self.has_image_input:\n        x = torch.cat([x, y], dim=1)  # (b, c_x + c_y, f, h, w)\n        clip_embdding = self.img_emb(clip_feature)\n        context = torch.cat([clip_embdding, context], dim=1)\n    \n    x, (f, h, w) = self.patchify(x)\n    \n    freqs = torch.cat([\n        self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),\n        self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),\n        self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)\n    ], dim=-1).reshape(f * h * w, 1, -1).to(x.device)\n\n    # Context Parallel\n    chunks = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)\n    pad_shape = chunks[0].shape[1] - chunks[-1].shape[1]\n    
chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in chunks]\n    x = chunks[get_sequence_parallel_rank()]\n\n    for block in self.blocks:\n        if self.training:\n            x = gradient_checkpoint_forward(\n                block,\n                use_gradient_checkpointing,\n                use_gradient_checkpointing_offload,\n                x, context, t_mod, freqs\n            )\n        else:\n            x = block(x, context, t_mod, freqs)\n\n    x = self.head(x, t)\n\n    # Context Parallel\n    x = get_sp_group().all_gather(x, dim=1)\n    x = x[:, :-pad_shape] if pad_shape > 0 else x\n\n    # unpatchify\n    x = self.unpatchify(x, (f, h, w))\n    return x\n\n\ndef usp_attn_forward(self, x, freqs):\n    q = self.norm_q(self.q(x))\n    k = self.norm_k(self.k(x))\n    v = self.v(x)\n\n    q = rope_apply(q, freqs, self.num_heads)\n    k = rope_apply(k, freqs, self.num_heads)\n    q = rearrange(q, \"b s (n d) -> b s n d\", n=self.num_heads)\n    k = rearrange(k, \"b s (n d) -> b s n d\", n=self.num_heads)\n    v = rearrange(v, \"b s (n d) -> b s n d\", n=self.num_heads)\n\n    attn_type = AttnType.FA\n    ring_impl_type = \"basic\"\n    if IS_NPU_AVAILABLE:\n        attn_type = AttnType.NPU\n        ring_impl_type = \"basic_npu\"\n    x = xFuserLongContextAttention(attn_type=attn_type, ring_impl_type=ring_impl_type)(\n        None,\n        query=q,\n        key=k,\n        value=v,\n    )\n    x = x.flatten(2)\n\n    del q, k, v\n    getattr(torch, parse_device_type(x.device)).empty_cache()\n    return self.o(x)\n\n\ndef get_current_chunk(x, dim=1):\n    chunks = torch.chunk(x, get_sequence_parallel_world_size(), dim=dim)\n    ndims = len(chunks[0].shape)\n    pad_list = [0] * (2 * ndims)\n    pad_end_index = 2 * (ndims - 1 - dim) + 1\n    max_size = chunks[0].size(dim)\n    chunks = [\n        torch.nn.functional.pad(\n            chunk, \n            tuple(pad_list[:pad_end_index] + [max_size - 
chunk.size(dim)] + pad_list[pad_end_index+1:]), \n            value=0\n        ) \n        for chunk in chunks\n    ]\n    x = chunks[get_sequence_parallel_rank()]\n    return x\n\n\ndef gather_all_chunks(x, seq_len=None, dim=1):\n    x = get_sp_group().all_gather(x, dim=dim)\n    if seq_len is not None:\n        slices = [slice(None)] * x.ndim\n        slices[dim] = slice(0, seq_len)\n        x = x[tuple(slices)]\n    return x\n"
  },
  {
    "path": "diffsynth/version.py",
    "content": "# Make sure to modify __release_datetime__ to release time when making official release.\n__version__ = '2.0.0'\n# default release datetime for branches under active development is set\n# to be a time far-far-away-into-the-future\n__release_datetime__ = '2099-10-13 08:56:12'"
  },
  {
    "path": "docs/en/.readthedocs.yaml",
    "content": "# .readthedocs.yaml\n# Read the Docs configuration file\n# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details\n\n# Required\nversion: 2\n\n# Set the OS, Python version and other tools you might need\nbuild:\n  os: ubuntu-22.04\n  tools:\n    python: \"3.10\"\n\n# Build documentation in the \"docs/\" directory with Sphinx\nsphinx:\n  configuration: docs/en/conf.py\n\n# Optionally build your docs in additional formats such as PDF and ePub\n# formats:\n#    - pdf\n#    - epub\n\n# Optional but recommended, declare the Python requirements required\n# to build your documentation\n# See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html\npython:\n   install:\n      - requirements: docs/requirements.txt\n"
  },
  {
    "path": "docs/en/API_Reference/core/attention.md",
    "content": "# `diffsynth.core.attention`: Attention Mechanism Implementation\n\n`diffsynth.core.attention` provides routing mechanisms for attention mechanism implementations, automatically selecting efficient attention implementations based on available packages in the `Python` environment and [environment variables](../../Pipeline_Usage/Environment_Variables.md#diffsynth_attention_implementation).\n\n## Attention Mechanism\n\nThe attention mechanism is a model structure proposed in the paper [\"Attention Is All You Need\"](https://arxiv.org/abs/1706.03762). In the original paper, the attention mechanism is implemented according to the following formula:\n\n$$\n\\text{Attention}(Q, K, V) = \\text{Softmax}\\left(\n    \\frac{QK^T}{\\sqrt{d_k}}\n\\right)\nV.\n$$\n\nIn `PyTorch`, it can be implemented with the following code:\n```python\nimport torch\n\ndef attention(query, key, value):\n    scale_factor = 1 / query.size(-1)**0.5\n    attn_weight = query @ key.transpose(-2, -1) * scale_factor\n    attn_weight = torch.softmax(attn_weight, dim=-1)\n    return attn_weight @ value\n\nquery = torch.rand(32, 8, 128, 64, dtype=torch.bfloat16, device=\"cuda\")\nkey = torch.rand(32, 8, 128, 64, dtype=torch.bfloat16, device=\"cuda\")\nvalue = torch.rand(32, 8, 128, 64, dtype=torch.bfloat16, device=\"cuda\")\noutput_1 = attention(query, key, value)\n```\n\nThe dimensions of `query`, `key`, and `value` are $(b, n, s, d)$:\n* $b$: Batch size\n* $n$: Number of attention heads\n* $s$: Sequence length\n* $d$: Dimension of each attention head\n\nThis computation does not include any trainable parameters. 
In modern transformer architectures, Linear layers are applied before and after this computation, but the \"attention mechanism\" discussed in this article refers only to the computation in the above code and excludes those layers.\n\n## More Efficient Implementations\n\nNote that the shape of the attention score matrix ( $\\text{Softmax}\\left(\\frac{QK^T}{\\sqrt{d_k}}\\right)$ in the formula, `attn_weight` in the code) is $(b, n, s, s)$, so both the time and space complexity of this computation are quadratic in the sequence length $s$, which is typically very large. Taking image generation models as an example, when the width and height of the image are doubled, the sequence length increases 4-fold, and the computational load and memory requirements increase 16-fold. To avoid these high costs, more efficient attention implementations are needed, including:\n* Flash Attention 3: [GitHub](https://github.com/Dao-AILab/flash-attention), [Paper](https://arxiv.org/abs/2407.08608)\n* Flash Attention 2: [GitHub](https://github.com/Dao-AILab/flash-attention), [Paper](https://arxiv.org/abs/2307.08691)\n* Sage Attention: [GitHub](https://github.com/thu-ml/SageAttention), [Paper](https://arxiv.org/abs/2505.11594)\n* xFormers: [GitHub](https://github.com/facebookresearch/xformers), [Documentation](https://facebookresearch.github.io/xformers/components/ops.html#module-xformers.ops)\n* PyTorch: [GitHub](https://github.com/pytorch/pytorch), [Documentation](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)\n\nTo use attention implementations other than `PyTorch`, install the corresponding packages by following the instructions on their GitHub pages. 
`DiffSynth-Studio` will automatically route to the corresponding implementation based on available packages in the Python environment, or can be controlled through [environment variables](../../Pipeline_Usage/Environment_Variables.md#diffsynth_attention_implementation).\n\n```python\nfrom diffsynth.core.attention import attention_forward\nimport torch\n\ndef attention(query, key, value):\n    scale_factor = 1 / query.size(-1)**0.5\n    attn_weight = query @ key.transpose(-2, -1) * scale_factor\n    attn_weight = torch.softmax(attn_weight, dim=-1)\n    return attn_weight @ value\n\nquery = torch.rand(32, 8, 128, 64, dtype=torch.bfloat16, device=\"cuda\")\nkey = torch.rand(32, 8, 128, 64, dtype=torch.bfloat16, device=\"cuda\")\nvalue = torch.rand(32, 8, 128, 64, dtype=torch.bfloat16, device=\"cuda\")\noutput_1 = attention(query, key, value)\noutput_2 = attention_forward(query, key, value)\nprint((output_1 - output_2).abs().mean())\n```\n\nPlease note that acceleration will introduce errors, but in most cases, the error is negligible.\n\n## Developer Guide\n\nWhen integrating new models into `DiffSynth-Studio`, developers can decide whether to call `attention_forward` in `diffsynth.core.attention`, but we expect models to prioritize calling this module as much as possible, so that new attention mechanism implementations can take effect directly on these models.\n\n## Best Practices\n\n**In most cases, we recommend directly using the native `PyTorch` implementation without installing any additional packages.** Although other attention mechanism implementations can accelerate, the acceleration effect is relatively limited, and in a few cases, compatibility and precision issues may arise.\n\nIn addition, efficient attention mechanism implementations will gradually be integrated into `PyTorch`. The `scaled_dot_product_attention` in `PyTorch` version 2.9.0 has already integrated Flash Attention 2. 
We still provide this interface in `DiffSynth-Studio` so that more aggressive acceleration schemes can be put into practice quickly, even though their stability still needs time to be verified."
  },
  {
    "path": "docs/en/API_Reference/core/data.md",
    "content": "# `diffsynth.core.data`: Data Processing Operators and Universal Dataset\n\n## Data Processing Operators\n\n### Available Data Processing Operators\n\n`diffsynth.core.data` provides a series of data processing operators for data processing, including:\n\n* Data format conversion operators\n    * `ToInt`: Convert to int format\n    * `ToFloat`: Convert to float format\n    * `ToStr`: Convert to str format\n    * `ToList`: Convert to list format, wrapping this data in a list\n    * `ToAbsolutePath`: Convert relative paths to absolute paths\n* File loading operators\n    * `LoadImage`: Read image files\n    * `LoadVideo`: Read video files\n    * `LoadAudio`: Read audio files\n    * `LoadGIF`: Read GIF files\n    * `LoadTorchPickle`: Read binary files saved by [`torch.save`](https://docs.pytorch.org/docs/stable/generated/torch.save.html) [This operator may cause code injection attacks in binary files, please use with caution!]\n* Media file processing operators\n    * `ImageCropAndResize`: Crop and resize images\n* Meta operators\n    * `SequencialProcess`: Route each data in the sequence to an operator\n    * `RouteByExtensionName`: Route to specific operators by file extension\n    * `RouteByType`: Route to specific operators by data type\n\n### Operator Usage\n\nData operators are connected with the `>>` symbol to form data processing pipelines, for example:\n\n```python\nfrom diffsynth.core.data.operators import *\n\ndata = \"image.jpg\"\ndata_pipeline = ToAbsolutePath(base_path=\"/data\") >> LoadImage() >> ImageCropAndResize(max_pixels=512*512)\ndata = data_pipeline(data)\n```\n\nAfter passing through each operator, the data is processed in sequence:\n\n* `ToAbsolutePath(base_path=\"/data\")`: `\"/data/image.jpg\"`\n* `LoadImage()`: `<PIL.Image.Image image mode=RGB size=1024x1024 at 0x7F8E7AAEFC10>`\n* `ImageCropAndResize(max_pixels=512*512)`: `<PIL.Image.Image image mode=RGB size=512x512 at 0x7F8E7A936F20>`\n\nWe can compose functionally complete 
data pipelines, for example, the default video data operator for the universal dataset is:\n\n```python\nRouteByType(operator_map=[\n    (str, ToAbsolutePath(base_path) >> RouteByExtensionName(operator_map=[\n        ((\"jpg\", \"jpeg\", \"png\", \"webp\"), LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor) >> ToList()),\n        ((\"gif\",), LoadGIF(\n            num_frames, time_division_factor, time_division_remainder,\n            frame_processor=ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor),\n        )),\n        ((\"mp4\", \"avi\", \"mov\", \"wmv\", \"mkv\", \"flv\", \"webm\"), LoadVideo(\n            num_frames, time_division_factor, time_division_remainder,\n            frame_processor=ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor),\n        )),\n    ])),\n])\n```\n\nIt includes the following logic:\n\n* If the data is of type `str`\n    * If it's a `\"jpg\", \"jpeg\", \"png\", \"webp\"` type file\n        * Load this image\n        * Crop and scale to a specific resolution\n        * Pack into a list, treating it as a single-frame video\n    * If it's a `\"gif\"` type file\n        * Load the GIF file content\n        * Crop and scale each frame to a specific resolution\n    * If it's a `\"mp4\", \"avi\", \"mov\", \"wmv\", \"mkv\", \"flv\", \"webm\"` type file\n        * Load the video file content\n        * Crop and scale each frame to a specific resolution\n* If the data is not of type `str`, an error is reported\n\n## Universal Dataset\n\n`diffsynth.core.data` provides a unified dataset implementation. The dataset requires the following parameters:\n\n* `base_path`: Root directory. 
If the dataset contains relative paths to image files, this field needs to be filled in to load the files pointed to by these paths\n* `metadata_path`: Path to the metadata file that describes every sample (e.g. file paths and prompts); supports `csv`, `json`, `jsonl` formats\n* `repeat`: Data repetition count, defaults to 1; this parameter affects the number of training steps in an epoch\n* `data_file_keys`: Data field names that need to be loaded, for example `(image, edit_image)`\n* `main_data_operator`: Main loading operator, needs to assemble the data processing pipeline through data processing operators\n* `special_operator_map`: Special operator mapping, operator mappings built for fields that require special processing\n\n### Metadata\n\nThe dataset's `metadata_path` points to a metadata file, supporting `csv`, `json`, `jsonl` formats. The following provides examples:\n\n* `csv` format: High readability, does not support list data, small memory footprint\n\n```csv\nimage,prompt\nimage_1.jpg,\"a dog\"\nimage_2.jpg,\"a cat\"\n```\n\n* `json` format: High readability, supports list data, large memory footprint\n\n```json\n[\n    {\n        \"image\": \"image_1.jpg\",\n        \"prompt\": \"a dog\"\n    },\n    {\n        \"image\": \"image_2.jpg\",\n        \"prompt\": \"a cat\"\n    }\n]\n```\n\n* `jsonl` format: Low readability, supports list data, small memory footprint\n\n```json\n{\"image\": \"image_1.jpg\", \"prompt\": \"a dog\"}\n{\"image\": \"image_2.jpg\", \"prompt\": \"a cat\"}\n```\n\nHow to choose the best metadata format?\n\n* If the data volume is large (tens of millions of samples), `json` is not suitable because parsing `json` files requires additional memory. Please use `csv` or `jsonl` format\n* If the dataset contains list data, such as for edit models that require multiple images as input, `csv` is not suitable because it cannot store list data. 
Please use `json` or `jsonl` format\n\n### Data Loading Logic\n\nBy default, the dataset outputs the fields from the metadata file as-is, so image and video file paths are output as strings. To load these files, you need to set `data_file_keys`, `main_data_operator`, and `special_operator_map`.\n\nEach field is processed according to the following logic:\n* If the field is in `special_operator_map`, call the corresponding operator in `special_operator_map` for processing\n* If the field is not in `special_operator_map`\n    * If the field is in `data_file_keys`, call the `main_data_operator` operator for processing\n    * If the field is not in `data_file_keys`, no processing is done\n\n`special_operator_map` can be used to implement special data processing. For example, in the model [Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B), the input character face video `animate_face_video` is processed at a fixed resolution that differs from the resolution of the output video. Therefore, this field is processed by a dedicated operator:\n\n```python\nspecial_operator_map={\n    \"animate_face_video\": ToAbsolutePath(args.dataset_base_path) >> LoadVideo(args.num_frames, 4, 1, frame_processor=ImageCropAndResize(512, 512, None, 16, 16)),\n}\n```\n\n### Other Notes\n\nWhen the dataset is very small, you can increase `repeat` to lengthen each epoch, avoiding the considerable overhead of saving the model too frequently.\n\nWhen data volume * `repeat` exceeds $10^9$, we observe that the dataset becomes significantly slower. This appears to be a `PyTorch` bug, and we are not sure whether newer versions of `PyTorch` have fixed it."
  },
  {
    "path": "docs/en/API_Reference/core/gradient.md",
    "content": "# `diffsynth.core.gradient`: Gradient Checkpointing and Offload\n\n`diffsynth.core.gradient` provides encapsulated gradient checkpointing and its Offload version for model training.\n\n## Gradient Checkpointing\n\nGradient checkpointing is a technique used to reduce memory usage during training. We provide an example to help you understand this technique. Here is a simple model structure:\n\n```python\nimport torch\n\nclass ToyModel(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.activation = torch.nn.Sigmoid()\n    \n    def forward(self, x):\n        return self.activation(x)\n\nmodel = ToyModel()\nx = torch.randn((2, 3))\ny = model(x)\n```\n\nIn this model structure, the input parameter $x$ passes through the Sigmoid activation function to obtain the output value $y=\\frac{1}{1+e^{-x}}$.\n\nDuring the training process, assuming our loss function value is $\\mathcal L$, when backpropagating gradients, we obtain $\\frac{\\partial \\mathcal L}{\\partial y}$. At this point, we need to calculate $\\frac{\\partial \\mathcal L}{\\partial x}$. It's not difficult to find that $\\frac{\\partial y}{\\partial x}=y(1-y)$, and thus $\\frac{\\partial \\mathcal L}{\\partial x}=\\frac{\\partial \\mathcal L}{\\partial y}\\frac{\\partial y}{\\partial x}=\\frac{\\partial \\mathcal L}{\\partial y}y(1-y)$. If we save the value of $y$ during the model's forward propagation and directly compute $y(1-y)$ during gradient backpropagation, this will avoid complex exp computations, speeding up the calculation. However, this requires additional memory to store the intermediate variable $y$.\n\nWhen gradient checkpointing is not enabled, the training framework will default to storing all intermediate variables that assist gradient computation, thereby achieving optimal computational speed. When gradient checkpointing is enabled, intermediate variables are not stored, but the input parameter $x$ is still stored, reducing memory usage. 
During gradient backpropagation, these variables need to be recomputed, slowing down the calculation.\n\n## Enabling Gradient Checkpointing and Its Offload\n\n`gradient_checkpoint_forward` in `diffsynth.core.gradient` implements gradient checkpointing and its Offload version. It can be called as follows:\n\n```python\nimport torch\nfrom diffsynth.core.gradient import gradient_checkpoint_forward\n\nclass ToyModel(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.activation = torch.nn.Sigmoid()\n    \n    def forward(self, x):\n        return self.activation(x)\n\nmodel = ToyModel()\nx = torch.randn((2, 3))\ny = gradient_checkpoint_forward(\n    model,\n    use_gradient_checkpointing=True,\n    use_gradient_checkpointing_offload=False,\n    x=x,\n)\n```\n\n* When `use_gradient_checkpointing=False` and `use_gradient_checkpointing_offload=False`, the computation is exactly the same as the original computation and does not affect the model's inference or training, so you can integrate it directly into your code.\n* When `use_gradient_checkpointing=True` and `use_gradient_checkpointing_offload=False`, gradient checkpointing is enabled.\n* When `use_gradient_checkpointing_offload=True`, gradient checkpointing is enabled, and all gradient checkpoint input parameters are offloaded to CPU memory (RAM), further reducing VRAM usage at the cost of slower computation.\n\n## Best Practices\n\n> Q: Where should gradient checkpointing be enabled?\n> \n> A: Enabling gradient checkpointing for the entire model as a single unit is optimal for neither computational efficiency nor memory usage. We need to set fine-grained gradient checkpoints, but we don't want to add too much complicated code to the framework. 
Therefore, we recommend implementing it in the `model_fn` of the `Pipeline`, for example, `model_fn_qwen_image` in `diffsynth/pipelines/qwen_image.py`, which enables gradient checkpointing at the Block level without modifying any code in the model structure.\n\n> Q: When should gradient checkpointing be enabled?\n> \n> A: As models grow ever larger, gradient checkpointing has become a necessary training technique and should usually be enabled. Gradient checkpointing Offload should only be enabled for models whose activations occupy excessive VRAM (such as video generation models)."
  },
  {
    "path": "docs/en/API_Reference/core/loader.md",
    "content": "# `diffsynth.core.loader`: Model Download and Loading\n\nThis document introduces the model download and loading functionalities in `diffsynth.core.loader`.\n\n## ModelConfig\n\n`ModelConfig` in `diffsynth.core.loader` is used to annotate model download sources, local paths, VRAM management configurations, and other information.\n\n### Downloading and Loading Models from Remote Sources\n\nTaking the model [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny) as an example, after filling in `model_id` and `origin_file_pattern` in `ModelConfig`, the model can be automatically downloaded. By default, it downloads to the `./models` path, which can be modified through the [environment variable DIFFSYNTH_MODEL_BASE_PATH](../../Pipeline_Usage/Environment_Variables.md#diffsynth_model_base_path).\n\nBy default, even if the model has already been downloaded, the program will still query the remote for any missing files. 
To completely disable remote requests, set the [environment variable DIFFSYNTH_SKIP_DOWNLOAD](../../Pipeline_Usage/Environment_Variables.md#diffsynth_skip_download) to `True`.\n\n```python\nfrom diffsynth.core import ModelConfig\n\nconfig = ModelConfig(\n    model_id=\"DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny\",\n    origin_file_pattern=\"model.safetensors\",\n)\n# Download models\nconfig.download_if_necessary()\nprint(config.path)\n```\n\nAfter calling `download_if_necessary`, the model is downloaded automatically if needed, and its local path is stored in `config.path`.\n\n### Loading Models from Local Paths\n\nTo load a model from a local path, set `path`:\n\n```python\nfrom diffsynth.core import ModelConfig\n\nconfig = ModelConfig(path=\"models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny/model.safetensors\")\n```\n\nIf the model is split into multiple shard files, pass them as a list:\n\n```python\nfrom diffsynth.core import ModelConfig\n\nconfig = ModelConfig(path=[\n    \"models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors\",\n    \"models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors\",\n    \"models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors\",\n    \"models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors\"\n])\n```\n\n### VRAM Management Configuration\n\n`ModelConfig` also contains VRAM management configuration information. 
See [VRAM Management](../../Pipeline_Usage/VRAM_management.md#more-usage-methods) for details.\n\n## Model File Loading\n\n`diffsynth.core.loader` provides a unified `load_state_dict` for loading state dicts from model files.\n\nLoading a single model file:\n\n```python\nfrom diffsynth.core import load_state_dict\n\nstate_dict = load_state_dict(\"models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny/model.safetensors\")\n```\n\nLoading multiple model files (merged into one state dict):\n\n```python\nfrom diffsynth.core import load_state_dict\n\nstate_dict = load_state_dict([\n    \"models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors\",\n    \"models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors\",\n    \"models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors\",\n    \"models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors\"\n])\n```\n\n## Model Hash\n\nModel hash is used to determine the model type. The hash value can be obtained through `hash_model_file`:\n\n```python\nfrom diffsynth.core import hash_model_file\n\nprint(hash_model_file(\"models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny/model.safetensors\"))\n```\n\nThe hash value of multiple model files can also be calculated, which is equivalent to calculating the model hash value after merging the state dict:\n\n```python\nfrom diffsynth.core import hash_model_file\n\nprint(hash_model_file([\n    \"models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors\",\n    \"models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors\",\n    \"models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors\",\n    \"models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors\"\n]))\n```\n\nThe model hash value is only related to the keys and tensor shapes in the state dict of the model file, and is unrelated to the numerical values of the model parameters, file saving time, and other information. 
For `.safetensors` files, `hash_model_file` completes almost instantly because it does not need to read the model parameters. However, for `.bin`, `.pth`, `.ckpt`, and other binary formats, all model parameters must be read, so **we recommend that developers stop using these file formats.**\n\nBy [writing model Config](../../Developer_Guide/Integrating_Your_Model.md#step-3-writing-model-config) and adding the model hash value and other information to `diffsynth/configs/model_configs.py`, developers can let `DiffSynth-Studio` automatically identify the model type and load it.\n\n## Model Loading\n\n`load_model` is the public entry point for loading models in `diffsynth.core.loader`. It calls [skip_model_initialization](../../API_Reference/core/vram.md#skipping-model-parameter-initialization) to skip model parameter initialization. If [Disk Offload](../../Pipeline_Usage/VRAM_management.md#disk-offload) is enabled, it calls [DiskMap](../../API_Reference/core/vram.md#state-dict-disk-mapping) for lazy loading. If Disk Offload is not enabled, it calls [load_state_dict](#model-file-loading) to load model parameters. If necessary, it also calls a [state dict converter](../../Developer_Guide/Integrating_Your_Model.md#step-2-model-file-format-conversion) to convert the model format. 
Finally, it calls `model.eval()` to switch to inference mode.\n\nHere is a usage example with Disk Offload enabled:\n\n```python\nfrom diffsynth.core import load_model, enable_vram_management, AutoWrappedLinear, AutoWrappedModule\nfrom diffsynth.models.qwen_image_dit import QwenImageDiT, RMSNorm\nimport torch\n\nprefix = \"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model\"\nmodel_path = [prefix + f\"-0000{i}-of-00009.safetensors\" for i in range(1, 10)]\n\nmodel = load_model(\n    QwenImageDiT,\n    model_path,\n    module_map={\n        torch.nn.Linear: AutoWrappedLinear,\n        RMSNorm: AutoWrappedModule,\n    },\n    vram_config={\n        \"offload_dtype\": \"disk\",\n        \"offload_device\": \"disk\",\n        \"onload_dtype\": \"disk\",\n        \"onload_device\": \"disk\",\n        \"preparing_dtype\": torch.bfloat16,\n        \"preparing_device\": \"cuda\",\n        \"computation_dtype\": torch.bfloat16,\n        \"computation_device\": \"cuda\",\n    },\n    vram_limit=0,\n)\n```"
  },
  {
    "path": "docs/en/API_Reference/core/vram.md",
    "content": "# `diffsynth.core.vram`: VRAM Management\n\nThis document introduces the underlying VRAM management functionalities in `diffsynth.core.vram`. If you wish to use these functionalities in other codebases, you can refer to this document.\n\n## Skipping Model Parameter Initialization\n\nBy default, when a model is instantiated in `PyTorch`, its parameters are allocated in VRAM or memory and initialized, but these values are overwritten as soon as the pretrained weights are loaded, so the initialization is wasted computation. `PyTorch` does not provide a convenient interface to skip this redundant work. We provide `skip_model_initialization` in `diffsynth.core.vram` to skip model parameter initialization.\n\nDefault model loading approach:\n\n```python\nfrom diffsynth.core import load_state_dict\nfrom diffsynth.models.qwen_image_controlnet import QwenImageBlockWiseControlNet\n\nmodel = QwenImageBlockWiseControlNet() # Slow\npath = \"models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny/model.safetensors\"\nstate_dict = load_state_dict(path, device=\"cpu\")\nmodel.load_state_dict(state_dict, assign=True)\n```\n\nModel loading approach that skips parameter initialization:\n\n```python\nfrom diffsynth.core import load_state_dict, skip_model_initialization\nfrom diffsynth.models.qwen_image_controlnet import QwenImageBlockWiseControlNet\n\nwith skip_model_initialization():\n    model = QwenImageBlockWiseControlNet() # Fast\npath = \"models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny/model.safetensors\"\nstate_dict = load_state_dict(path, device=\"cpu\")\nmodel.load_state_dict(state_dict, assign=True)\n```\n\nIn `DiffSynth-Studio`, all pretrained models follow this loading logic. 
After developers [integrate models](../../Developer_Guide/Integrating_Your_Model.md), they can directly load models quickly using this approach.\n\n## State Dict Disk Mapping\n\nWhen only a subset of the parameters in a model's pretrained weight file needs to be read, State Dict Disk Mapping can speed up this process. We provide `DiskMap` in `diffsynth.core.vram` for on-demand loading of model parameters.\n\nDefault weight loading approach:\n\n```python\nfrom diffsynth.core import load_state_dict\n\npath = \"models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny/model.safetensors\"\nstate_dict = load_state_dict(path, device=\"cpu\") # Slow\nprint(state_dict[\"img_in.weight\"])\n```\n\nUsing `DiskMap` to load only specific parameters:\n\n```python\nfrom diffsynth.core import DiskMap\n\npath = \"models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny/model.safetensors\"\nstate_dict = DiskMap(path, device=\"cpu\") # Fast\nprint(state_dict[\"img_in.weight\"])\n```\n\n`DiskMap` is the basic component of Disk Offload in `DiffSynth-Studio`. After developers [configure fine-grained VRAM management schemes](../../Developer_Guide/Enabling_VRAM_management.md), they can directly enable Disk Offload.\n\n`DiskMap` relies on features specific to the `.safetensors` format. For `.bin`, `.pth`, `.ckpt`, and other binary files, model parameters are always fully loaded, so Disk Offload does not support files in these formats. **We do not recommend that developers continue using files in these formats.**\n\n## Replacable Modules for VRAM Management\n\nWhen `DiffSynth-Studio`'s VRAM management is enabled, the modules inside the model are replaced with the replaceable modules in `diffsynth.core.vram.layers`. For usage, see [Fine-grained VRAM Management Scheme](../../Developer_Guide/Enabling_VRAM_management.md#writing-fine-grained-vram-management-scheme)."
  },
  {
    "path": "docs/en/Developer_Guide/Building_a_Pipeline.md",
    "content": "# Building a Pipeline\n\nAfter [integrating the required models for the Pipeline](../Developer_Guide/Integrating_Your_Model.md), you also need to build a `Pipeline` for model inference. This document provides a standardized process for building a `Pipeline`. Developers can also refer to existing `Pipeline` implementations for construction.\n\nThe `Pipeline` implementation is located in `diffsynth/pipelines`. Each `Pipeline` contains the following essential key components:\n\n* `__init__`\n* `from_pretrained`\n* `__call__`\n* `units`\n* `model_fn`\n\n## `__init__`\n\nIn `__init__`, the `Pipeline` is initialized. Here is a simple implementation:\n\n```python\nimport torch\nfrom PIL import Image\nfrom typing import Union\nfrom tqdm import tqdm\nfrom ..diffusion import FlowMatchScheduler\nfrom ..core import ModelConfig\nfrom ..diffusion.base_pipeline import BasePipeline, PipelineUnit\nfrom ..models.new_models import XXX_Model, YYY_Model, ZZZ_Model\n\nclass NewDiffSynthPipeline(BasePipeline):\n\n    def __init__(self, device=\"cuda\", torch_dtype=torch.bfloat16):\n        super().__init__(device=device, torch_dtype=torch_dtype)\n        self.scheduler = FlowMatchScheduler()\n        self.text_encoder: XXX_Model = None\n        self.dit: YYY_Model = None\n        self.vae: ZZZ_Model = None\n        self.in_iteration_models = (\"dit\",)\n        self.units = [\n            NewDiffSynthPipelineUnit_xxx(),\n            ...\n        ]\n        self.model_fn = model_fn_new\n```\n\nThis includes the following parts:\n\n* `scheduler`: Scheduler, used to control the coefficients in the iterative formula during inference, controlling the noise content at each step.\n* `text_encoder`, `dit`, `vae`: Models. Since [Latent Diffusion](https://arxiv.org/abs/2112.10752) was proposed, this three-stage model architecture has become the mainstream Diffusion model architecture. 
However, this is not a fixed rule, and any number of models can be added to the `Pipeline`.\n* `in_iteration_models`: Iteration models. This tuple marks which models will be called during iteration.\n* `units`: Pre-processing units for model iteration. See [`units`](#units) for details.\n* `model_fn`: The `forward` function of the denoising model during iteration. See [`model_fn`](#model_fn) for details.\n\n> Q: Since model loading does not happen in `__init__`, why is each model initialized to `None` here?\n> \n> A: Annotating the type of each model here lets the code editor provide code completion for each model, which makes subsequent development easier.\n\n## `from_pretrained`\n\n`from_pretrained` is responsible for loading the required models to make the `Pipeline` callable. Here is a simple implementation:\n\n```python\n    @staticmethod\n    def from_pretrained(\n        torch_dtype: torch.dtype = torch.bfloat16,\n        device: Union[str, torch.device] = \"cuda\",\n        model_configs: list[ModelConfig] = [],\n        vram_limit: float = None,\n    ):\n        # Initialize pipeline\n        pipe = NewDiffSynthPipeline(device=device, torch_dtype=torch_dtype)\n        model_pool = pipe.download_and_load_models(model_configs, vram_limit)\n        \n        # Fetch models\n        pipe.text_encoder = model_pool.fetch_model(\"xxx_text_encoder\")\n        pipe.dit = model_pool.fetch_model(\"yyy_dit\")\n        pipe.vae = model_pool.fetch_model(\"zzz_vae\")\n        # If necessary, load tokenizers here.\n        \n        # VRAM Management\n        pipe.vram_management_enabled = pipe.check_vram_management_state()\n        return pipe\n```\n\nDevelopers need to implement the logic for fetching models. The corresponding model names are the `\"model_name\"` in the [model Config filled in during model integration](../Developer_Guide/Integrating_Your_Model.md#step-3-writing-model-config).\n\nSome models also need to load a `tokenizer`. 
Extra `tokenizer_config` parameters can be added to `from_pretrained` as needed, and this part can be implemented after fetching the models.\n\n## `__call__`\n\n`__call__` implements the entire generation process of the Pipeline. Below is a common generation process template. Developers can modify it based on their needs.\n\n```python\n    @torch.no_grad()\n    def __call__(\n        self,\n        prompt: str,\n        negative_prompt: str = \"\",\n        cfg_scale: float = 4.0,\n        input_image: Image.Image = None,\n        denoising_strength: float = 1.0,\n        height: int = 1328,\n        width: int = 1328,\n        seed: int = None,\n        rand_device: str = \"cpu\",\n        num_inference_steps: int = 30,\n        progress_bar_cmd = tqdm,\n    ):\n        # Scheduler\n        self.scheduler.set_timesteps(\n            num_inference_steps,\n            denoising_strength=denoising_strength\n        )\n        \n        # Parameters\n        inputs_posi = {\n            \"prompt\": prompt,\n        }\n        inputs_nega = {\n            \"negative_prompt\": negative_prompt,\n        }\n        inputs_shared = {\n            \"cfg_scale\": cfg_scale,\n            \"input_image\": input_image,\n            \"denoising_strength\": denoising_strength,\n            \"height\": height,\n            \"width\": width,\n            \"seed\": seed,\n            \"rand_device\": rand_device,\n            \"num_inference_steps\": num_inference_steps,\n        }\n        for unit in self.units:\n            inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)\n\n        # Denoise\n        self.load_models_to_device(self.in_iteration_models)\n        models = {name: getattr(self, name) for name in self.in_iteration_models}\n        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):\n            timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, 
device=self.device)\n\n            # Inference\n            noise_pred_posi = self.model_fn(**models, **inputs_shared, **inputs_posi, timestep=timestep, progress_id=progress_id)\n            if cfg_scale != 1.0:\n                noise_pred_nega = self.model_fn(**models, **inputs_shared, **inputs_nega, timestep=timestep, progress_id=progress_id)\n                noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)\n            else:\n                noise_pred = noise_pred_posi\n\n            # Scheduler\n            inputs_shared[\"latents\"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)\n        \n        # Decode\n        self.load_models_to_device(['vae'])\n        image = self.vae.decode(inputs_shared[\"latents\"], device=self.device)\n        image = self.vae_output_to_image(image)\n        self.load_models_to_device([])\n\n        return image\n```\n\n## `units`\n\n`units` contains all the preprocessing processes, such as: width/height checking, prompt encoding, initial noise generation, etc. 
Throughout preprocessing, data is divided into three mutually exclusive parts, each stored in its own dictionary:\n\n* `inputs_shared`: Shared inputs, parameters unrelated to [Classifier-Free Guidance](https://arxiv.org/abs/2207.12598) (CFG for short).\n* `inputs_posi`: Positive side inputs for Classifier-Free Guidance, containing content related to positive prompts.\n* `inputs_nega`: Negative side inputs for Classifier-Free Guidance, containing content related to negative prompts.\n\nThere are three types of Pipeline Unit implementations: direct mode, CFG separation mode, and takeover mode.\n\nIf some calculations are unrelated to CFG, direct mode can be used, for example, Qwen-Image's random noise initialization:\n\n```python\nclass QwenImageUnit_NoiseInitializer(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            input_params=(\"height\", \"width\", \"seed\", \"rand_device\"),\n            output_params=(\"noise\",),\n        )\n\n    def process(self, pipe: QwenImagePipeline, height, width, seed, rand_device):\n        noise = pipe.generate_noise((1, 16, height//8, width//8), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)\n        return {\"noise\": noise}\n```\n\nIf some calculations are related to CFG and must process the positive and negative prompts separately, but the input parameters on both sides are the same, CFG separation mode can be used, for example, Qwen-Image's prompt encoding:\n\n```python\nclass QwenImageUnit_PromptEmbedder(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            seperate_cfg=True,\n            input_params_posi={\"prompt\": \"prompt\"},\n            input_params_nega={\"prompt\": \"negative_prompt\"},\n            input_params=(\"edit_image\",),\n            output_params=(\"prompt_emb\", \"prompt_emb_mask\"),\n            onload_model_names=(\"text_encoder\",)\n        )\n\n    def process(self, pipe: QwenImagePipeline, prompt, 
edit_image=None) -> dict:\n        pipe.load_models_to_device(self.onload_model_names)\n        # Do something\n        return {\"prompt_emb\": prompt_embeds, \"prompt_emb_mask\": encoder_attention_mask}\n```\n\nIf some calculations need global information, takeover mode is required, for example, Qwen-Image's entity partition control:\n\n```python\nclass QwenImageUnit_EntityControl(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            take_over=True,\n            input_params=(\"eligen_entity_prompts\", \"width\", \"height\", \"eligen_enable_on_negative\", \"cfg_scale\"),\n            output_params=(\"entity_prompt_emb\", \"entity_masks\", \"entity_prompt_emb_mask\"),\n            onload_model_names=(\"text_encoder\",)\n        )\n\n    def process(self, pipe: QwenImagePipeline, inputs_shared, inputs_posi, inputs_nega):\n        # Do something\n        return inputs_shared, inputs_posi, inputs_nega\n```\n\nThe following are the parameter configurations required for Pipeline Unit:\n\n* `seperate_cfg`: Whether to enable CFG separation mode\n* `take_over`: Whether to enable takeover mode\n* `input_params`: Shared input parameters\n* `output_params`: Output parameters\n* `input_params_posi`: Positive side input parameters\n* `input_params_nega`: Negative side input parameters\n* `onload_model_names`: Names of model components to be called\n\nWhen designing `unit`, please try to follow these principles:\n\n* Default fallback: For optional function `unit` input parameters, the default is `None` rather than `False` or other values. Please provide fallback processing for this default value.\n* Parameter triggering: Some Adapter models may not be loaded, such as ControlNet. The corresponding `unit` should control triggering based on whether the parameter input is `None` rather than whether the model is loaded. 
For example, when the user provides `controlnet_image` but has not loaded the ControlNet model, the code should raise an error rather than ignore these input parameters and continue execution.\n* Simplicity first: Use direct mode whenever possible, and only use takeover mode when the functionality cannot be implemented any other way.\n* VRAM efficiency: When calling models in `unit`, please use `pipe.load_models_to_device(self.onload_model_names)` to activate the corresponding models. Do not call other models outside `onload_model_names`. After `unit` calculation is completed, do not manually release VRAM with `pipe.load_models_to_device([])`.\n\n> Q: Some parameters, such as `output_params`, are not used during inference. Is it still necessary to configure them?\n> \n> A: These parameters do not affect inference, but they do affect some experimental features, so we recommend configuring them properly. For example, in \"split training\", preprocessing can be completed offline during training, but model calculations that require gradient backpropagation cannot be split off. These parameters are used to build a computational graph that determines which calculations can be split.\n\n## `model_fn`\n\n`model_fn` is the unified `forward` interface during iteration. For models that do not yet have a mature open-source ecosystem, you can directly use the denoising model's `forward`, for example:\n\n```python\ndef model_fn_new(dit=None, latents=None, timestep=None, prompt_emb=None, **kwargs):\n    return dit(latents, prompt_emb, timestep)\n```\n\nFor models with rich open-source ecosystems, `model_fn` usually contains complex, intertwined cross-model inference logic. Taking `diffsynth/pipelines/qwen_image.py` as an example, the additional calculations implemented in this function include: entity partition control, three types of ControlNet, Gradient Checkpointing, etc. Developers need to be extra careful when implementing this part to avoid conflicts between module features."
  },
  {
    "path": "docs/en/Developer_Guide/Enabling_VRAM_management.md",
    "content": "# Fine-Grained VRAM Management Scheme\n\nThis document introduces how to write reasonable fine-grained VRAM management schemes for models, and how to use the VRAM management functions in `DiffSynth-Studio` for other external code libraries. Before reading this document, please read the document [VRAM Management](../Pipeline_Usage/VRAM_management.md).\n\n## How Much VRAM Does a 20B Model Need?\n\nTaking Qwen-Image's DiT model as an example, this model has reached 20B parameters. The following code will load this model and perform inference, requiring about 40G VRAM. This model obviously cannot run on consumer-grade GPUs with smaller VRAM.\n\n```python\nfrom diffsynth.core import load_model\nfrom diffsynth.models.qwen_image_dit import QwenImageDiT\nfrom modelscope import snapshot_download\nimport torch\n\nsnapshot_download(\n    model_id=\"Qwen/Qwen-Image\",\n    local_dir=\"models/Qwen/Qwen-Image\",\n    allow_file_pattern=\"transformer/*\"\n)\nprefix = \"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model\"\nmodel_path = [prefix + f\"-0000{i}-of-00009.safetensors\" for i in range(1, 10)]\ninputs = {\n    \"latents\": torch.randn((1, 16, 128, 128), dtype=torch.bfloat16, device=\"cuda\"),\n    \"timestep\": torch.zeros((1,), dtype=torch.bfloat16, device=\"cuda\"),\n    \"prompt_emb\": torch.randn((1, 5, 3584), dtype=torch.bfloat16, device=\"cuda\"),\n    \"prompt_emb_mask\": torch.ones((1, 5), dtype=torch.int64, device=\"cuda\"),\n    \"height\": 1024,\n    \"width\": 1024,\n}\n\nmodel = load_model(QwenImageDiT, model_path, torch_dtype=torch.bfloat16, device=\"cuda\")\nwith torch.no_grad():\n    output = model(**inputs)\n```\n\n## Writing Fine-Grained VRAM Management Scheme\n\nTo write a fine-grained VRAM management scheme, we need to use `print(model)` to observe and analyze the model structure:\n\n```\nQwenImageDiT(\n  (pos_embed): QwenEmbedRope()\n  (time_text_embed): TimestepEmbeddings(\n    (time_proj): TemporalTimesteps()\n    
(timestep_embedder): DiffusersCompatibleTimestepProj(\n      (linear_1): Linear(in_features=256, out_features=3072, bias=True)\n      (act): SiLU()\n      (linear_2): Linear(in_features=3072, out_features=3072, bias=True)\n    )\n  )\n  (txt_norm): RMSNorm()\n  (img_in): Linear(in_features=64, out_features=3072, bias=True)\n  (txt_in): Linear(in_features=3584, out_features=3072, bias=True)\n  (transformer_blocks): ModuleList(\n    (0-59): 60 x QwenImageTransformerBlock(\n      (img_mod): Sequential(\n        (0): SiLU()\n        (1): Linear(in_features=3072, out_features=18432, bias=True)\n      )\n      (img_norm1): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)\n      (attn): QwenDoubleStreamAttention(\n        (to_q): Linear(in_features=3072, out_features=3072, bias=True)\n        (to_k): Linear(in_features=3072, out_features=3072, bias=True)\n        (to_v): Linear(in_features=3072, out_features=3072, bias=True)\n        (norm_q): RMSNorm()\n        (norm_k): RMSNorm()\n        (add_q_proj): Linear(in_features=3072, out_features=3072, bias=True)\n        (add_k_proj): Linear(in_features=3072, out_features=3072, bias=True)\n        (add_v_proj): Linear(in_features=3072, out_features=3072, bias=True)\n        (norm_added_q): RMSNorm()\n        (norm_added_k): RMSNorm()\n        (to_out): Sequential(\n          (0): Linear(in_features=3072, out_features=3072, bias=True)\n        )\n        (to_add_out): Linear(in_features=3072, out_features=3072, bias=True)\n      )\n      (img_norm2): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)\n      (img_mlp): QwenFeedForward(\n        (net): ModuleList(\n          (0): ApproximateGELU(\n            (proj): Linear(in_features=3072, out_features=12288, bias=True)\n          )\n          (1): Dropout(p=0.0, inplace=False)\n          (2): Linear(in_features=12288, out_features=3072, bias=True)\n        )\n      )\n      (txt_mod): Sequential(\n        (0): SiLU()\n        (1): Linear(in_features=3072, 
out_features=18432, bias=True)\n      )\n      (txt_norm1): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)\n      (txt_norm2): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)\n      (txt_mlp): QwenFeedForward(\n        (net): ModuleList(\n          (0): ApproximateGELU(\n            (proj): Linear(in_features=3072, out_features=12288, bias=True)\n          )\n          (1): Dropout(p=0.0, inplace=False)\n          (2): Linear(in_features=12288, out_features=3072, bias=True)\n        )\n      )\n    )\n  )\n  (norm_out): AdaLayerNorm(\n    (linear): Linear(in_features=3072, out_features=6144, bias=True)\n    (norm): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)\n  )\n  (proj_out): Linear(in_features=3072, out_features=64, bias=True)\n)\n```\n\nIn VRAM management, we only care about layers containing parameters. In this model structure, `QwenEmbedRope`, `TemporalTimesteps`, `SiLU` and other Layers do not contain parameters. `LayerNorm` also does not contain parameters because `elementwise_affine=False` is set. Layers containing parameters are only `Linear` and `RMSNorm`.\n\n`diffsynth.core.vram` provides two replacement modules for VRAM management:\n* `AutoWrappedLinear`: Used to replace `Linear` layers\n* `AutoWrappedModule`: Used to replace any other layer\n\nWrite a `module_map` to map `Linear` and `RMSNorm` in the model to the corresponding modules:\n\n```python\nmodule_map={\n    torch.nn.Linear: AutoWrappedLinear,\n    RMSNorm: AutoWrappedModule,\n}\n```\n\nIn addition, `vram_config` and `vram_limit` are also required, which have been introduced in [VRAM Management](../Pipeline_Usage/VRAM_management.md#more-usage-methods).\n\nCall `enable_vram_management` to enable VRAM management. 
Note that the `device` when loading the model is `cpu`, consistent with `offload_device`:\n\n```python\nfrom diffsynth.core import load_model, enable_vram_management, AutoWrappedLinear, AutoWrappedModule\nfrom diffsynth.models.qwen_image_dit import QwenImageDiT, RMSNorm\nimport torch\n\nprefix = \"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model\"\nmodel_path = [prefix + f\"-0000{i}-of-00009.safetensors\" for i in range(1, 10)]\ninputs = {\n    \"latents\": torch.randn((1, 16, 128, 128), dtype=torch.bfloat16, device=\"cuda\"),\n    \"timestep\": torch.zeros((1,), dtype=torch.bfloat16, device=\"cuda\"),\n    \"prompt_emb\": torch.randn((1, 5, 3584), dtype=torch.bfloat16, device=\"cuda\"),\n    \"prompt_emb_mask\": torch.ones((1, 5), dtype=torch.int64, device=\"cuda\"),\n    \"height\": 1024,\n    \"width\": 1024,\n}\n\nmodel = load_model(QwenImageDiT, model_path, torch_dtype=torch.bfloat16, device=\"cpu\")\nenable_vram_management(\n    model,\n    module_map={\n        torch.nn.Linear: AutoWrappedLinear,\n        RMSNorm: AutoWrappedModule,\n    },\n    vram_config = {\n        \"offload_dtype\": torch.bfloat16,\n        \"offload_device\": \"cpu\",\n        \"onload_dtype\": torch.bfloat16,\n        \"onload_device\": \"cpu\",\n        \"preparing_dtype\": torch.bfloat16,\n        \"preparing_device\": \"cuda\",\n        \"computation_dtype\": torch.bfloat16,\n        \"computation_device\": \"cuda\",\n    },\n    vram_limit=0,\n)\nwith torch.no_grad():\n    output = model(**inputs)\n```\n\nThe above code only requires 2G VRAM to run the `forward` of a 20B model.\n\n## Disk Offload\n\n[Disk Offload](../Pipeline_Usage/VRAM_management.md#disk-offload) is a special VRAM management scheme that needs to be enabled during the model loading process, not after the model is loaded. 
Usually, when the above code can run smoothly, Disk Offload can be directly enabled:\n\n```python\nfrom diffsynth.core import load_model, enable_vram_management, AutoWrappedLinear, AutoWrappedModule\nfrom diffsynth.models.qwen_image_dit import QwenImageDiT, RMSNorm\nimport torch\n\nprefix = \"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model\"\nmodel_path = [prefix + f\"-0000{i}-of-00009.safetensors\" for i in range(1, 10)]\ninputs = {\n    \"latents\": torch.randn((1, 16, 128, 128), dtype=torch.bfloat16, device=\"cuda\"),\n    \"timestep\": torch.zeros((1,), dtype=torch.bfloat16, device=\"cuda\"),\n    \"prompt_emb\": torch.randn((1, 5, 3584), dtype=torch.bfloat16, device=\"cuda\"),\n    \"prompt_emb_mask\": torch.ones((1, 5), dtype=torch.int64, device=\"cuda\"),\n    \"height\": 1024,\n    \"width\": 1024,\n}\n\nmodel = load_model(\n    QwenImageDiT,\n    model_path,\n    module_map={\n        torch.nn.Linear: AutoWrappedLinear,\n        RMSNorm: AutoWrappedModule,\n    },\n    vram_config={\n        \"offload_dtype\": \"disk\",\n        \"offload_device\": \"disk\",\n        \"onload_dtype\": \"disk\",\n        \"onload_device\": \"disk\",\n        \"preparing_dtype\": torch.bfloat16,\n        \"preparing_device\": \"cuda\",\n        \"computation_dtype\": torch.bfloat16,\n        \"computation_device\": \"cuda\",\n    },\n    vram_limit=0,\n)\nwith torch.no_grad():\n    output = model(**inputs)\n```\n\nDisk Offload is an extremely special VRAM management scheme. 
It only supports `.safetensors` files, not binary files such as `.bin`, `.pth`, and `.ckpt`, and it does not support [state dict converters](../Developer_Guide/Integrating_Your_Model.md#step-2-model-file-format-conversion) that reshape tensors.\n\nIf a model runs normally without Disk Offload but fails when Disk Offload is enabled, please submit an issue to us on GitHub.\n\n## Writing Default Configuration\n\nTo make the VRAM management function easier to use, we store the fine-grained VRAM management configuration in `diffsynth/configs/vram_management_module_maps.py`. The configuration for the above model is:\n\n```python\n\"diffsynth.models.qwen_image_dit.QwenImageDiT\": {\n    \"diffsynth.models.qwen_image_dit.RMSNorm\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n    \"torch.nn.Linear\": \"diffsynth.core.vram.layers.AutoWrappedLinear\",\n}\n```\n\n# Fine-Grained VRAM Management Scheme\n\nThis document introduces how to write reasonable fine-grained VRAM management schemes for models, and how to use the VRAM management functions in `DiffSynth-Studio` for other external code libraries. Before reading this document, please read the document [VRAM Management](../Pipeline_Usage/VRAM_management.md).\n\n## How Much VRAM Does a 20B Model Need?\n\nTaking Qwen-Image's DiT model as an example, this model has reached 20B parameters. The following code will load this model and perform inference, requiring about 40G VRAM. 
This model obviously cannot run on consumer-grade GPUs with smaller VRAM.\n\n```python\nfrom diffsynth.core import load_model\nfrom diffsynth.models.qwen_image_dit import QwenImageDiT\nfrom modelscope import snapshot_download\nimport torch\n\nsnapshot_download(\n    model_id=\"Qwen/Qwen-Image\",\n    local_dir=\"models/Qwen/Qwen-Image\",\n    allow_file_pattern=\"transformer/*\"\n)\nprefix = \"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model\"\nmodel_path = [prefix + f\"-0000{i}-of-00009.safetensors\" for i in range(1, 10)]\ninputs = {\n    \"latents\": torch.randn((1, 16, 128, 128), dtype=torch.bfloat16, device=\"cuda\"),\n    \"timestep\": torch.zeros((1,), dtype=torch.bfloat16, device=\"cuda\"),\n    \"prompt_emb\": torch.randn((1, 5, 3584), dtype=torch.bfloat16, device=\"cuda\"),\n    \"prompt_emb_mask\": torch.ones((1, 5), dtype=torch.int64, device=\"cuda\"),\n    \"height\": 1024,\n    \"width\": 1024,\n}\n\nmodel = load_model(QwenImageDiT, model_path, torch_dtype=torch.bfloat16, device=\"cuda\")\nwith torch.no_grad():\n    output = model(**inputs)\n```\n\n## Writing Fine-Grained VRAM Management Scheme\n\nTo write a fine-grained VRAM management scheme, we need to use `print(model)` to observe and analyze the model structure:\n\n```\nQwenImageDiT(\n  (pos_embed): QwenEmbedRope()\n  (time_text_embed): TimestepEmbeddings(\n    (time_proj): TemporalTimesteps()\n    (timestep_embedder): DiffusersCompatibleTimestepProj(\n      (linear_1): Linear(in_features=256, out_features=3072, bias=True)\n      (act): SiLU()\n      (linear_2): Linear(in_features=3072, out_features=3072, bias=True)\n    )\n  )\n  (txt_norm): RMSNorm()\n  (img_in): Linear(in_features=64, out_features=3072, bias=True)\n  (txt_in): Linear(in_features=3584, out_features=3072, bias=True)\n  (transformer_blocks): ModuleList(\n    (0-59): 60 x QwenImageTransformerBlock(\n      (img_mod): Sequential(\n        (0): SiLU()\n        (1): Linear(in_features=3072, out_features=18432, bias=True)\n   
   )\n      (img_norm1): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)\n      (attn): QwenDoubleStreamAttention(\n        (to_q): Linear(in_features=3072, out_features=3072, bias=True)\n        (to_k): Linear(in_features=3072, out_features=3072, bias=True)\n        (to_v): Linear(in_features=3072, out_features=3072, bias=True)\n        (norm_q): RMSNorm()\n        (norm_k): RMSNorm()\n        (add_q_proj): Linear(in_features=3072, out_features=3072, bias=True)\n        (add_k_proj): Linear(in_features=3072, out_features=3072, bias=True)\n        (add_v_proj): Linear(in_features=3072, out_features=3072, bias=True)\n        (norm_added_q): RMSNorm()\n        (norm_added_k): RMSNorm()\n        (to_out): Sequential(\n          (0): Linear(in_features=3072, out_features=3072, bias=True)\n        )\n        (to_add_out): Linear(in_features=3072, out_features=3072, bias=True)\n      )\n      (img_norm2): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)\n      (img_mlp): QwenFeedForward(\n        (net): ModuleList(\n          (0): ApproximateGELU(\n            (proj): Linear(in_features=3072, out_features=12288, bias=True)\n          )\n          (1): Dropout(p=0.0, inplace=False)\n          (2): Linear(in_features=12288, out_features=3072, bias=True)\n        )\n      )\n      (txt_mod): Sequential(\n        (0): SiLU()\n        (1): Linear(in_features=3072, out_features=18432, bias=True)\n      )\n      (txt_norm1): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)\n      (txt_norm2): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)\n      (txt_mlp): QwenFeedForward(\n        (net): ModuleList(\n          (0): ApproximateGELU(\n            (proj): Linear(in_features=3072, out_features=12288, bias=True)\n          )\n          (1): Dropout(p=0.0, inplace=False)\n          (2): Linear(in_features=12288, out_features=3072, bias=True)\n        )\n      )\n    )\n  )\n  (norm_out): AdaLayerNorm(\n    (linear): Linear(in_features=3072, 
out_features=6144, bias=True)\n    (norm): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)\n  )\n  (proj_out): Linear(in_features=3072, out_features=64, bias=True)\n)\n```\n\nIn VRAM management, we only care about layers that contain parameters. In this model structure, `QwenEmbedRope`, `TemporalTimesteps`, `SiLU`, and other such layers do not contain parameters. `LayerNorm` does not contain parameters either, because `elementwise_affine=False` is set. The only layers that contain parameters are `Linear` and `RMSNorm`.\n\n`diffsynth.core.vram` provides two replacement modules for VRAM management:\n* `AutoWrappedLinear`: Used to replace `Linear` layers\n* `AutoWrappedModule`: Used to replace any other layer\n\nWrite a `module_map` to map `Linear` and `RMSNorm` in the model to the corresponding modules:\n\n```python\nmodule_map={\n    torch.nn.Linear: AutoWrappedLinear,\n    RMSNorm: AutoWrappedModule,\n}\n```\n\nIn addition, `vram_config` and `vram_limit` are required; both are introduced in [VRAM Management](../Pipeline_Usage/VRAM_management.md#more-usage-methods).\n\nCall `enable_vram_management` to enable VRAM management. 
Note that the `device` when loading the model is `cpu`, consistent with `offload_device`:\n\n```python\nfrom diffsynth.core import load_model, enable_vram_management, AutoWrappedLinear, AutoWrappedModule\nfrom diffsynth.models.qwen_image_dit import QwenImageDiT, RMSNorm\nimport torch\n\nprefix = \"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model\"\nmodel_path = [prefix + f\"-0000{i}-of-00009.safetensors\" for i in range(1, 10)]\ninputs = {\n    \"latents\": torch.randn((1, 16, 128, 128), dtype=torch.bfloat16, device=\"cuda\"),\n    \"timestep\": torch.zeros((1,), dtype=torch.bfloat16, device=\"cuda\"),\n    \"prompt_emb\": torch.randn((1, 5, 3584), dtype=torch.bfloat16, device=\"cuda\"),\n    \"prompt_emb_mask\": torch.ones((1, 5), dtype=torch.int64, device=\"cuda\"),\n    \"height\": 1024,\n    \"width\": 1024,\n}\n\nmodel = load_model(QwenImageDiT, model_path, torch_dtype=torch.bfloat16, device=\"cpu\")\nenable_vram_management(\n    model,\n    module_map={\n        torch.nn.Linear: AutoWrappedLinear,\n        RMSNorm: AutoWrappedModule,\n    },\n    vram_config = {\n        \"offload_dtype\": torch.bfloat16,\n        \"offload_device\": \"cpu\",\n        \"onload_dtype\": torch.bfloat16,\n        \"onload_device\": \"cpu\",\n        \"preparing_dtype\": torch.bfloat16,\n        \"preparing_device\": \"cuda\",\n        \"computation_dtype\": torch.bfloat16,\n        \"computation_device\": \"cuda\",\n    },\n    vram_limit=0,\n)\nwith torch.no_grad():\n    output = model(**inputs)\n```\n\nThe above code only requires 2G VRAM to run the `forward` of a 20B model.\n\n## Disk Offload\n\n[Disk Offload](../Pipeline_Usage/VRAM_management.md#disk-offload) is a special VRAM management scheme that needs to be enabled during the model loading process, not after the model is loaded. 
Usually, when the above code can run smoothly, Disk Offload can be directly enabled:\n\n```python\nfrom diffsynth.core import load_model, enable_vram_management, AutoWrappedLinear, AutoWrappedModule\nfrom diffsynth.models.qwen_image_dit import QwenImageDiT, RMSNorm\nimport torch\n\nprefix = \"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model\"\nmodel_path = [prefix + f\"-0000{i}-of-00009.safetensors\" for i in range(1, 10)]\ninputs = {\n    \"latents\": torch.randn((1, 16, 128, 128), dtype=torch.bfloat16, device=\"cuda\"),\n    \"timestep\": torch.zeros((1,), dtype=torch.bfloat16, device=\"cuda\"),\n    \"prompt_emb\": torch.randn((1, 5, 3584), dtype=torch.bfloat16, device=\"cuda\"),\n    \"prompt_emb_mask\": torch.ones((1, 5), dtype=torch.int64, device=\"cuda\"),\n    \"height\": 1024,\n    \"width\": 1024,\n}\n\nmodel = load_model(\n    QwenImageDiT,\n    model_path,\n    module_map={\n        torch.nn.Linear: AutoWrappedLinear,\n        RMSNorm: AutoWrappedModule,\n    },\n    vram_config={\n        \"offload_dtype\": \"disk\",\n        \"offload_device\": \"disk\",\n        \"onload_dtype\": \"disk\",\n        \"onload_device\": \"disk\",\n        \"preparing_dtype\": torch.bfloat16,\n        \"preparing_device\": \"cuda\",\n        \"computation_dtype\": torch.bfloat16,\n        \"computation_device\": \"cuda\",\n    },\n    vram_limit=0,\n)\nwith torch.no_grad():\n    output = model(**inputs)\n```\n\nDisk Offload is an extremely special VRAM management scheme. 
It only supports `.safetensors` files, not binary formats such as `.bin`, `.pth`, or `.ckpt`, and it does not support [state dict converters](../Developer_Guide/Integrating_Your_Model.md#step-2-model-file-format-conversion) that reshape tensors.\n\nIf a model runs normally without Disk Offload but fails when Disk Offload is enabled, please submit an issue to us on GitHub.\n\n## Writing Default Configuration\n\nTo make the VRAM management function easier to use, we record the fine-grained VRAM management configuration in `diffsynth/configs/vram_management_module_maps.py`. The configuration for the above model is:\n\n```python\n\"diffsynth.models.qwen_image_dit.QwenImageDiT\": {\n    \"diffsynth.models.qwen_image_dit.RMSNorm\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n    \"torch.nn.Linear\": \"diffsynth.core.vram.layers.AutoWrappedLinear\",\n}\n```"
  },
  {
    "path": "docs/en/Developer_Guide/Integrating_Your_Model.md",
    "content": "# Integrating Model Architecture\n\nThis document introduces how to integrate models into the `DiffSynth-Studio` framework for use by modules such as `Pipeline`.\n\n## Step 1: Integrate Model Architecture Code\n\nAll model architecture implementations in `DiffSynth-Studio` are located in `diffsynth/models`. Each `.py` code file implements one model architecture, and all models are loaded through `ModelPool` in `diffsynth/models/model_loader.py`. When integrating a new model architecture, please create a new `.py` file under this path.\n\n```shell\ndiffsynth/models/\n├── general_modules.py\n├── model_loader.py\n├── qwen_image_controlnet.py\n├── qwen_image_dit.py\n├── qwen_image_text_encoder.py\n├── qwen_image_vae.py\n└── ...\n```\n\nIn most cases, we recommend integrating models as native `PyTorch` code, with the model architecture class directly inheriting from `torch.nn.Module`, for example:\n\n```python\nimport torch\n\nclass NewDiffSynthModel(torch.nn.Module):\n    def __init__(self, dim=1024):\n        super().__init__()\n        self.linear = torch.nn.Linear(dim, dim)\n        self.activation = torch.nn.Sigmoid()\n    \n    def forward(self, x):\n        x = self.linear(x)\n        x = self.activation(x)\n        return x\n```\n\nIf the model architecture implementation contains additional dependencies, we strongly recommend removing them; otherwise, they will cause heavy package dependency issues. Among our existing models, Qwen-Image's Blockwise ControlNet is integrated in this way. 
The code is lightweight; please refer to `diffsynth/models/qwen_image_controlnet.py`.\n\nIf the model has already been integrated into a Huggingface library ([`transformers`](https://huggingface.co/docs/transformers/main/index), [`diffusers`](https://huggingface.co/docs/diffusers/main/index), etc.), we can integrate the model in a simpler way:\n\n<details>\n<summary>Integrating Huggingface Library Style Model Architecture Code</summary>\n\nIn the Huggingface libraries, these models are loaded as follows:\n\n```python\nfrom transformers import XXX_Model\n\nmodel = XXX_Model.from_pretrained(\"path_to_your_model\")\n```\n\n`DiffSynth-Studio` does not support loading models through `from_pretrained` because this conflicts with VRAM management and other functions. Please rewrite the model architecture in the following format:\n\n```python\nimport torch\n\nclass DiffSynth_XXX_Model(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        from transformers import XXX_Config, XXX_Model\n        config = XXX_Config(**{\n            \"architectures\": [\"XXX_Model\"],\n            \"other_configs\": \"Please copy and paste the other configs here.\",\n        })\n        self.model = XXX_Model(config)\n        \n    def forward(self, x):\n        outputs = self.model(x)\n        return outputs\n```\n\nHere, `XXX_Config` is the Config class corresponding to the model. For example, the Config class for `Qwen2_5_VLModel` is `Qwen2_5_VLConfig`, which can be found by consulting its source code. The content of the Config can usually be found in the `config.json` file in the model repository. `DiffSynth-Studio` does not read the `config.json` file, so its content needs to be copied and pasted into the code.\n\nIn rare cases, version updates of `transformers` and `diffusers` may break the import of some models. 
Therefore, if possible, we still recommend integrating models as native `PyTorch` code, as described at the beginning of Step 1.\n\nAmong our existing models, Qwen-Image's Text Encoder is integrated in this way. The code is lightweight; please refer to `diffsynth/models/qwen_image_text_encoder.py`.\n\n</details>\n\n## Step 2: Model File Format Conversion\n\nBecause developers in the open-source community provide model files in a variety of formats, we sometimes need to convert them to obtain a correctly formatted [state dict](https://docs.pytorch.org/tutorials/recipes/recipes/what_is_state_dict.html). This is common in the following situations:\n\n* Model files built by different code libraries, for example [Wan-AI/Wan2.1-T2V-1.3B](https://www.modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B) and [Wan-AI/Wan2.1-T2V-1.3B-Diffusers](https://www.modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B-Diffusers).\n* Models modified during integration, for example, the Text Encoder of [Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image) adds a `model.` prefix in `diffsynth/models/qwen_image_text_encoder.py`.\n* Model files containing multiple models, for example, the VACE Adapter and base DiT model of [Wan-AI/Wan2.1-VACE-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B) are mixed and stored in the same set of model files.\n\nIn our development philosophy, we hope to respect the wishes of model authors as much as possible. If we repackaged the model files, as in [Comfy-Org/Qwen-Image_ComfyUI](https://www.modelscope.cn/models/Comfy-Org/Qwen-Image_ComfyUI), the model would be more convenient to call, but traffic (model page views, downloads, etc.) would be directed elsewhere, and the original author of the model would lose the ability to delete the model. Therefore, we have added the `diffsynth/utils/state_dict_converters` module to the framework to convert file formats during model loading.\n\nThis logic is very simple. 
Taking Qwen-Image's Text Encoder as an example, only 10 lines of code are needed:\n\n```python\ndef QwenImageTextEncoderStateDictConverter(state_dict):\n    state_dict_ = {}\n    for k in state_dict:\n        v = state_dict[k]\n        if k.startswith(\"visual.\"):\n            k = \"model.\" + k\n        elif k.startswith(\"model.\"):\n            k = k.replace(\"model.\", \"model.language_model.\")\n        state_dict_[k] = v\n    return state_dict_\n```\n\n## Step 3: Writing Model Config\n\nModel Config is located in `diffsynth/configs/model_configs.py` and is used to identify model types and load them. The following fields need to be filled in:\n\n* `model_hash`: Model file hash value, which can be obtained through the `hash_model_file` function. This hash value depends only on the keys and tensor shapes in the model file's state dict, and is unrelated to other information in the file.\n* `model_name`: Model name, used by the `Pipeline` to identify the required model. If models with different architectures play the same role in a `Pipeline`, they can share the same `model_name`. When integrating a new model, just ensure that its `model_name` differs from those of existing models that play other roles. The corresponding model is fetched through `model_name` in the `Pipeline`'s `from_pretrained`.\n* `model_class`: Model architecture import path, pointing to the model architecture class implemented in Step 1, for example `diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder`.\n* `state_dict_converter`: Optional parameter. If model file format conversion is needed, fill in the import path of the conversion logic, for example `diffsynth.utils.state_dict_converters.qwen_image_text_encoder.QwenImageTextEncoderStateDictConverter`.\n* `extra_kwargs`: Optional parameter. If additional parameters need to be passed when initializing the model, these parameters need to be filled in. 
For example, models [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny) and [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint) both adopt the `QwenImageBlockWiseControlNet` structure in `diffsynth/models/qwen_image_controlnet.py`, but the latter also needs additional configuration `additional_in_dim=4`. Therefore, this configuration information needs to be filled in the `extra_kwargs` field.\n\nWe provide a piece of code to quickly understand how models are loaded through this configuration information:\n\n```python\nfrom diffsynth.core import hash_model_file, load_state_dict, skip_model_initialization\nfrom diffsynth.models.qwen_image_text_encoder import QwenImageTextEncoder\nfrom diffsynth.utils.state_dict_converters.qwen_image_text_encoder import QwenImageTextEncoderStateDictConverter\nimport torch\n\nmodel_hash = \"8004730443f55db63092006dd9f7110e\"\nmodel_name = \"qwen_image_text_encoder\"\nmodel_class = QwenImageTextEncoder\nstate_dict_converter = QwenImageTextEncoderStateDictConverter\nextra_kwargs = {}\n\nmodel_path = [\n    \"models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors\",\n    \"models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors\",\n    \"models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors\",\n    \"models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors\",\n]\nif hash_model_file(model_path) == model_hash:\n    with skip_model_initialization():\n        model = model_class(**extra_kwargs)\n    state_dict = load_state_dict(model_path, torch_dtype=torch.bfloat16, device=\"cuda\")\n    state_dict = state_dict_converter(state_dict)\n    model.load_state_dict(state_dict, assign=True)\n    print(\"Done!\")\n```\n\n> Q: The logic of the above code looks very simple, why is this part of code in 
`DiffSynth-Studio` extremely complex?\n> \n> A: We provide aggressive VRAM management functions that are coupled with the model loading logic, and this coupling makes the framework structure complex. We have tried our best to simplify the interface exposed to developers.\n\nThe `model_hash` in `diffsynth/configs/model_configs.py` is not necessarily unique, because a single model file may contain multiple models. In this situation, please use multiple model Configs to load each model separately, and write a corresponding `state_dict_converter` for each one to extract the parameters that model requires.\n\n## Step 4: Verifying Whether the Model Can Be Recognized and Loaded\n\nAfter model integration, the following code can be used to verify whether the model can be correctly recognized and loaded. The following code will attempt to load the model into memory:\n\n```python\nfrom diffsynth.models.model_loader import ModelPool\n\nmodel_pool = ModelPool()\nmodel_pool.auto_load_model(\n    [\n        \"models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors\",\n        \"models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors\",\n        \"models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors\",\n        \"models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors\",\n    ],\n)\n```\n\nIf the model can be recognized and loaded, you will see the following output:\n\n```\nLoading models from: [\n    \"models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors\",\n    \"models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors\",\n    \"models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors\",\n    \"models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors\"\n]\nLoaded model: {\n    \"model_name\": \"qwen_image_text_encoder\",\n    \"model_class\": \"diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder\",\n    \"extra_kwargs\": null\n}\n```\n\n## Step 5: Writing Model VRAM 
Management Scheme\n\n`DiffSynth-Studio` supports complex VRAM management. See [Enabling VRAM Management](../Developer_Guide/Enabling_VRAM_management.md) for details."
  },
  {
    "path": "docs/en/Developer_Guide/Training_Diffusion_Models.md",
    "content": "# Integrating Model Training\n\nAfter [integrating models](../Developer_Guide/Integrating_Your_Model.md) and [implementing Pipeline](../Developer_Guide/Building_a_Pipeline.md), the next step is to integrate model training functionality.\n\n## Training-Inference Consistent Pipeline Modification\n\nTo ensure strict consistency between training and inference processes, we will use most of the inference code during training, but still need to make minor modifications.\n\nFirst, add extra logic during inference to switch the image-to-image/video-to-video logic based on the `scheduler` state. Taking Qwen-Image as an example:\n\n```python\nclass QwenImageUnit_InputImageEmbedder(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            input_params=(\"input_image\", \"noise\", \"tiled\", \"tile_size\", \"tile_stride\"),\n            output_params=(\"latents\", \"input_latents\"),\n            onload_model_names=(\"vae\",)\n        )\n\n    def process(self, pipe: QwenImagePipeline, input_image, noise, tiled, tile_size, tile_stride):\n        if input_image is None:\n            return {\"latents\": noise, \"input_latents\": None}\n        pipe.load_models_to_device(['vae'])\n        image = pipe.preprocess_image(input_image).to(device=pipe.device, dtype=pipe.torch_dtype)\n        input_latents = pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)\n        if pipe.scheduler.training:\n            return {\"latents\": noise, \"input_latents\": input_latents}\n        else:\n            latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])\n            return {\"latents\": latents, \"input_latents\": input_latents}\n```\n\nThen, enable Gradient Checkpointing in `model_fn`, which will significantly reduce the VRAM required for training at the cost of computational speed. 
This is not mandatory, but we strongly recommend doing so.\n\nTaking Qwen-Image as an example, before modification:\n\n```python\ntext, image = block(\n    image=image,\n    text=text,\n    temb=conditioning,\n    image_rotary_emb=image_rotary_emb,\n    attention_mask=attention_mask,\n)\n```\n\nAfter modification:\n\n```python\nfrom ..core import gradient_checkpoint_forward\n\ntext, image = gradient_checkpoint_forward(\n    block,\n    use_gradient_checkpointing,\n    use_gradient_checkpointing_offload,\n    image=image,\n    text=text,\n    temb=conditioning,\n    image_rotary_emb=image_rotary_emb,\n    attention_mask=attention_mask,\n)\n```\n\n## Writing Training Scripts\n\n`DiffSynth-Studio` does not strictly encapsulate the training framework; instead, it exposes the training script content to developers, which makes it easier to modify the scripts to implement additional functions. To train a new model, developers can start from an existing training script, such as `examples/qwen_image/model_training/train.py`, and adapt it."
  },
  {
    "path": "docs/en/Makefile",
    "content": "# Minimal makefile for Sphinx documentation\n#\n\n# You can set these variables from the command line, and also\n# from the environment for the first two.\nSPHINXOPTS    ?=\nSPHINXBUILD   ?= sphinx-build\nSOURCEDIR     = .\nBUILDDIR      = _build\n\n# Put it first so that \"make\" without argument is like \"make help\".\nhelp:\n\t@$(SPHINXBUILD) -M help \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)\n\n.PHONY: help Makefile\n\n# Catch-all target: route all unknown targets to Sphinx using the new\n# \"make mode\" option.  $(O) is meant as a shortcut for $(SPHINXOPTS).\n%: Makefile\n\t@$(SPHINXBUILD) -M $@ \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)"
  },
  {
    "path": "docs/en/Model_Details/Anima.md",
    "content": "# Anima\n\nAnima is an image generation model trained and open-sourced by CircleStone Labs and Comfy Org.\n\n## Installation\n\nBefore using this project for model inference and training, please install DiffSynth-Studio first.\n\n```shell\ngit clone https://github.com/modelscope/DiffSynth-Studio.git\ncd DiffSynth-Studio\npip install -e .\n```\n\nFor more installation information, please refer to [Install Dependencies](../Pipeline_Usage/Setup.md).\n\n## Quick Start\n\nThe following code demonstrates how to quickly load the [circlestone-labs/Anima](https://www.modelscope.cn/models/circlestone-labs/Anima) model for inference. VRAM management is enabled by default, allowing the framework to automatically control model parameter loading based on available VRAM. Minimum 8GB VRAM required.\n\n```python\nfrom diffsynth.pipelines.anima_image import AnimaImagePipeline, ModelConfig\nimport torch\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": \"disk\",\n    \"onload_device\": \"disk\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = AnimaImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"circlestone-labs/Anima\", origin_file_pattern=\"split_files/diffusion_models/anima-preview.safetensors\", **vram_config),\n        ModelConfig(model_id=\"circlestone-labs/Anima\", origin_file_pattern=\"split_files/text_encoders/qwen_3_06b_base.safetensors\", **vram_config),\n        ModelConfig(model_id=\"circlestone-labs/Anima\", origin_file_pattern=\"split_files/vae/qwen_image_vae.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen3-0.6B\", origin_file_pattern=\"./\"),\n    tokenizer_t5xxl_config=ModelConfig(model_id=\"stabilityai/stable-diffusion-3.5-large\", 
origin_file_pattern=\"tokenizer_3/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\nprompt = \"Masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait.\"\nnegative_prompt = \"worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,\"\nimage = pipe(prompt, seed=0, num_inference_steps=50)\nimage.save(\"image.jpg\")\n```\n\n## Model Overview\n\n|Model ID|Inference|Low VRAM Inference|Full Training|Validation after Full Training|LoRA Training|Validation after LoRA Training|\n|-|-|-|-|-|-|-|\n|[circlestone-labs/Anima](https://www.modelscope.cn/models/circlestone-labs/Anima)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_inference/anima-preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_inference_low_vram/anima-preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/full/anima-preview.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/validate_full/anima-preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/lora/anima-preview.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/validate_lora/anima-preview.py)|\n\nSpecial training scripts:\n\n* Differential LoRA Training: [doc](../Training/Differential_LoRA.md)\n* FP8 Precision Training: [doc](../Training/FP8_Precision.md)\n* Two-Stage Split Training: [doc](../Training/Split_Training.md)\n* End-to-End Direct Distillation: [doc](../Training/Direct_Distill.md)\n\n## Model Inference\n\nModels are loaded through `AnimaImagePipeline.from_pretrained`, see [Model Inference](../Pipeline_Usage/Model_Inference.md#loading-models) for details.\n\nInput parameters for 
`AnimaImagePipeline` inference include:\n\n* `prompt`: Text description of the desired image content.\n* `negative_prompt`: Content to exclude from the generated image (default: `\"\"`).\n* `cfg_scale`: Classifier-free guidance parameter (default: 4.0).\n* `input_image`: Input image for image-to-image generation (default: `None`).\n* `denoising_strength`: Controls similarity to input image (default: 1.0).\n* `height`: Image height (must be multiple of 16, default: 1024).\n* `width`: Image width (must be multiple of 16, default: 1024).\n* `seed`: Random seed (default: `None`).\n* `rand_device`: Device for random noise generation (default: `\"cpu\"`).\n* `num_inference_steps`: Inference steps (default: 30).\n* `sigma_shift`: Scheduler sigma offset (default: `None`).\n* `progress_bar_cmd`: Progress bar implementation (default: `tqdm.tqdm`).\n\nFor VRAM constraints, enable [VRAM Management](../Pipeline_Usage/VRAM_management.md). Recommended low-VRAM configurations are provided in the \"Model Overview\" table above.\n\n## Model Training\n\nAnima models are trained through [`examples/anima/model_training/train.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/train.py) with parameters including:\n\n* General Training Parameters\n    * Dataset Configuration\n        * `--dataset_base_path`: Dataset root directory.\n        * `--dataset_metadata_path`: Metadata file path.\n        * `--dataset_repeat`: Dataset repetition per epoch.\n        * `--dataset_num_workers`: Dataloader worker count.\n        * `--data_file_keys`: Metadata fields to load (comma-separated).\n    * Model Loading\n        * `--model_paths`: Model paths (JSON format).\n        * `--model_id_with_origin_paths`: Model IDs with origin paths (e.g., `\"anima-team/anima-1B:text_encoder/*.safetensors\"`).\n        * `--extra_inputs`: Additional pipeline inputs (e.g., `controlnet_inputs` for ControlNet).\n        * `--fp8_models`: FP8-formatted models (same format as 
`--model_paths`).\n    * Training Configuration\n        * `--learning_rate`: Learning rate.\n        * `--num_epochs`: Training epochs.\n        * `--trainable_models`: Trainable components (e.g., `dit`, `vae`, `text_encoder`).\n        * `--find_unused_parameters`: Handle unused parameters in DDP training.\n        * `--weight_decay`: Weight decay value.\n        * `--task`: Training task (default: `sft`).\n    * Output Configuration\n        * `--output_path`: Model output directory.\n        * `--remove_prefix_in_ckpt`: Remove state dict prefixes.\n        * `--save_steps`: Model saving interval.\n    * LoRA Configuration\n        * `--lora_base_model`: Target model for LoRA.\n        * `--lora_target_modules`: Target modules for LoRA.\n        * `--lora_rank`: LoRA rank.\n        * `--lora_checkpoint`: LoRA checkpoint path.\n        * `--preset_lora_path`: Preloaded LoRA checkpoint path.\n        * `--preset_lora_model`: Model to merge LoRA with (e.g., `dit`).\n    * Gradient Configuration\n        * `--use_gradient_checkpointing`: Enable gradient checkpointing.\n        * `--use_gradient_checkpointing_offload`: Offload checkpointing to CPU.\n        * `--gradient_accumulation_steps`: Gradient accumulation steps.\n    * Image Resolution\n        * `--height`: Image height (empty for dynamic resolution).\n        * `--width`: Image width (empty for dynamic resolution).\n        * `--max_pixels`: Maximum pixel area for dynamic resolution.\n* Anima-Specific Parameters\n    * `--tokenizer_path`: Tokenizer path for text-to-image models.\n    * `--tokenizer_t5xxl_path`: T5-XXL tokenizer path.\n\nWe provide a sample image dataset for testing:\n\n```shell\nmodelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --local_dir ./data/diffsynth_example_dataset\n```\n\nFor training script details, refer to [Model Training](../Pipeline_Usage/Model_Training.md). 
For advanced training techniques, see [Training Framework Documentation](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/en/Training/)."
  },
  {
    "path": "docs/en/Model_Details/FLUX.md",
    "content": "# FLUX\n\n![Image](https://github.com/user-attachments/assets/c01258e2-f251-441a-aa1e-ebb22f02594d)\n\nFLUX is an image generation model series developed and open-sourced by Black Forest Labs.\n\n## Installation\n\nBefore using this project for model inference and training, please install DiffSynth-Studio first.\n\n```shell\ngit clone https://github.com/modelscope/DiffSynth-Studio.git\ncd DiffSynth-Studio\npip install -e .\n```\n\nFor more information about installation, please refer to [Install Dependencies](../Pipeline_Usage/Setup.md).\n\n## Quick Start\n\nRun the following code to quickly load the [black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev) model and perform inference. VRAM management is enabled, and the framework will automatically control model parameter loading based on remaining VRAM. Minimum 8GB VRAM is required to run.\n\n```python\nimport torch\nfrom diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig\n\nvram_config = {\n    \"offload_dtype\": torch.float8_e4m3fn,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.float8_e4m3fn,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e4m3fn,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = FluxImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"flux1-dev.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder/model.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder_2/*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"ae.safetensors\", **vram_config),\n    ],\n  
  vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 1,\n)\nprompt = \"CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her.\"\nimage = pipe(prompt=prompt, seed=0)\nimage.save(\"image.jpg\")\n```\n\n## Model Overview\n\n<details>\n\n<summary>Model Lineage</summary>\n\n```mermaid\ngraph LR;\n    FLUX.1-Series-->black-forest-labs/FLUX.1-dev;\n    FLUX.1-Series-->black-forest-labs/FLUX.1-Krea-dev;\n    FLUX.1-Series-->black-forest-labs/FLUX.1-Kontext-dev;\n    black-forest-labs/FLUX.1-dev-->FLUX.1-dev-ControlNet-Series;\n    FLUX.1-dev-ControlNet-Series-->alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta;\n    FLUX.1-dev-ControlNet-Series-->InstantX/FLUX.1-dev-Controlnet-Union-alpha;\n    FLUX.1-dev-ControlNet-Series-->jasperai/Flux.1-dev-Controlnet-Upscaler;\n    black-forest-labs/FLUX.1-dev-->InstantX/FLUX.1-dev-IP-Adapter;\n    black-forest-labs/FLUX.1-dev-->ByteDance/InfiniteYou;\n    black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/Eligen;\n    black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev;\n    black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev;\n    black-forest-labs/FLUX.1-dev-->ostris/Flex.2-preview;\n    black-forest-labs/FLUX.1-dev-->stepfun-ai/Step1X-Edit;\n    Qwen/Qwen2.5-VL-7B-Instruct-->stepfun-ai/Step1X-Edit;\n    black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/Nexus-GenV2;\n    Qwen/Qwen2.5-VL-7B-Instruct-->DiffSynth-Studio/Nexus-GenV2;\n```\n\n</details>\n\n| Model ID | Extra Parameters | Inference | Low VRAM Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training |\n| - | - | - | - | - | - | - | - |\n| 
[black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev) | | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev.py) |\n| [black-forest-labs/FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev) | | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-Krea-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-Krea-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-Krea-dev.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-Krea-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-Krea-dev.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-Krea-dev.py) |\n| [black-forest-labs/FLUX.1-Kontext-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev) | `kontext_images` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-Kontext-dev.py) | 
[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-Kontext-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-Kontext-dev.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-Kontext-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-Kontext-dev.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-Kontext-dev.py) |\n| [alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta](https://www.modelscope.cn/models/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta) | `controlnet_inputs` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-Controlnet-Inpainting-Beta.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Inpainting-Beta.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Inpainting-Beta.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Inpainting-Beta.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Inpainting-Beta.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Inpainting-Beta.py) |\n| [InstantX/FLUX.1-dev-Controlnet-Union-alpha](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-Controlnet-Union-alpha) | `controlnet_inputs` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-Controlnet-Union-alpha.py) | 
[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Union-alpha.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Union-alpha.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Union-alpha.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Union-alpha.py) |\n| [jasperai/Flux.1-dev-Controlnet-Upscaler](https://www.modelscope.cn/models/jasperai/Flux.1-dev-Controlnet-Upscaler) | `controlnet_inputs` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-Controlnet-Upscaler.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Upscaler.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Upscaler.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Upscaler.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Upscaler.py) |\n| [InstantX/FLUX.1-dev-IP-Adapter](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-IP-Adapter) | `ipadapter_images`, `ipadapter_scale` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-IP-Adapter.py) | 
[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-IP-Adapter.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-IP-Adapter.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-IP-Adapter.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-IP-Adapter.py) |\n| [ByteDance/InfiniteYou](https://www.modelscope.cn/models/ByteDance/InfiniteYou) | `infinityou_id_image`, `infinityou_guidance`, `controlnet_inputs` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-InfiniteYou.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-InfiniteYou.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-InfiniteYou.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-InfiniteYou.py) |\n| [DiffSynth-Studio/Eligen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen) | `eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-EliGen.py) | 
[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py) | - | - | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-EliGen.py) |\n| [DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev](https://www.modelscope.cn/models/DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev) | `lora_encoder_inputs`, `lora_encoder_scale` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-LoRA-Encoder.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py) | - | - |\n| [DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev) | | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py) | - | - | - | - | - |\n| [stepfun-ai/Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit) | `step1x_reference_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/Step1X-Edit.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/Step1X-Edit.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/Step1X-Edit.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/Step1X-Edit.py) | 
[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/Step1X-Edit.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/Step1X-Edit.py) |\n| [ostris/Flex.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview) | `flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, `flex_control_stop` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLEX.2-preview.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLEX.2-preview.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLEX.2-preview.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLEX.2-preview.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLEX.2-preview.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLEX.2-preview.py) |\n| [DiffSynth-Studio/Nexus-GenV2](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2) | `nexus_gen_reference_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/Nexus-Gen-Editing.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/Nexus-Gen.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/Nexus-Gen.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/Nexus-Gen.sh) | 
[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/Nexus-Gen.py) |\n\nSpecial Training Scripts:\n\n* Differential LoRA Training: [doc](../Training/Differential_LoRA.md)\n* FP8 Precision Training: [doc](../Training/FP8_Precision.md)\n* Two-stage Split Training: [doc](../Training/Split_Training.md)\n* End-to-end Direct Distillation: [doc](../Training/Direct_Distill.md)\n\n## Model Inference\n\nModels are loaded via `FluxImagePipeline.from_pretrained`, see [Loading Models](../Pipeline_Usage/Model_Inference.md#loading-models).\n\nInput parameters for `FluxImagePipeline` inference include:\n\n* `prompt`: Prompt describing the content appearing in the image.\n* `negative_prompt`: Negative prompt describing content that should not appear in the image, default value is `\"\"`.\n* `cfg_scale`: Classifier-free guidance parameter, default value is 1. When set to a value greater than 1, CFG is enabled.\n* `height`: Image height, must be a multiple of 16.\n* `width`: Image width, must be a multiple of 16.\n* `seed`: Random seed. Default is `None`, meaning completely random.\n* `rand_device`: Computing device for generating random Gaussian noise matrix, default is `\"cpu\"`. When set to `cuda`, different GPUs will produce different generation results.\n* `num_inference_steps`: Number of inference steps, default value is 30.\n* `embedded_guidance`: Embedded guidance parameter, default value is 3.5.\n* `t5_sequence_length`: Sequence length of the T5 text encoder, default is 512.\n* `tiled`: Whether to enable VAE tiling inference, default is `False`. 
Setting it to `True` can significantly reduce VRAM usage during VAE encoding/decoding stages, at the cost of slight numerical errors and slightly longer inference time.\n* `tile_size`: Tile size during VAE encoding/decoding stages, default is 128, only effective when `tiled=True`.\n* `tile_stride`: Tile stride during VAE encoding/decoding stages, default is 64, only effective when `tiled=True`, must be less than or equal to `tile_size`.\n* `progress_bar_cmd`: Progress bar, default is `tqdm.tqdm`. Can be disabled by setting it to `lambda x:x`.\n* `controlnet_inputs`: ControlNet model inputs, a list of `ControlNetInput` objects.\n* `ipadapter_images`: IP-Adapter model input image list.\n* `ipadapter_scale`: Guidance strength of the IP-Adapter model.\n* `infinityou_id_image`: InfiniteYou model input image.\n* `infinityou_guidance`: Guidance strength of the InfiniteYou model.\n* `kontext_images`: Kontext model input images.\n* `eligen_entity_prompts`: EliGen partition control prompt list.\n* `eligen_entity_masks`: EliGen partition control region mask image list.\n* `eligen_enable_on_negative`: Whether to enable EliGen partition control on the negative side of CFG.\n* `eligen_enable_inpaint`: Whether to enable the inpainting function of EliGen partition control.\n* `lora_encoder_inputs`: LoRA encoder input image list.\n* `lora_encoder_scale`: Guidance strength of the LoRA encoder.\n* `step1x_reference_image`: Step1X model reference image.\n* `flex_inpaint_image`: Flex model image to be inpainted.\n* `flex_inpaint_mask`: Flex model inpainting mask.\n* `flex_control_image`: Flex model control image.\n* `flex_control_strength`: Flex model control strength.\n* `flex_control_stop`: Flex model control stop timestep.\n* `nexus_gen_reference_image`: Nexus-Gen model reference image.\n\nIf VRAM is insufficient, please enable [VRAM Management](../Pipeline_Usage/VRAM_management.md). 
We provide recommended low VRAM configurations for each model in the example code, see the table in the \"Model Overview\" section above.\n\n## Model Training\n\nFLUX series models are uniformly trained through [`examples/flux/model_training/train.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/train.py), and the script parameters include:\n\n* General Training Parameters\n    * Dataset Basic Configuration\n        * `--dataset_base_path`: Root directory of the dataset.\n        * `--dataset_metadata_path`: Metadata file path of the dataset.\n        * `--dataset_repeat`: Number of times the dataset is repeated in each epoch.\n        * `--dataset_num_workers`: Number of processes for each DataLoader.\n        * `--data_file_keys`: Field names to be loaded from metadata, usually image or video file paths, separated by `,`.\n    * Model Loading Configuration\n        * `--model_paths`: Paths of models to be loaded. JSON format.\n        * `--model_id_with_origin_paths`: Model IDs with original paths, e.g., `\"black-forest-labs/FLUX.1-dev:flux1-dev.safetensors\"`. Separated by commas.\n        * `--extra_inputs`: Extra input parameters required by the model Pipeline, e.g., `controlnet_inputs` when training ControlNet models, separated by `,`.\n        * `--fp8_models`: Models loaded in FP8 format, consistent with `--model_paths` or `--model_id_with_origin_paths` format. Currently only supports models whose parameters are not updated by gradients (no gradient backpropagation, or gradients only update their LoRA).\n    * Training Basic Configuration\n        * `--learning_rate`: Learning rate.\n        * `--num_epochs`: Number of epochs.\n        * `--trainable_models`: Trainable models, e.g., `dit`, `vae`, `text_encoder`.\n        * `--find_unused_parameters`: Whether there are unused parameters in DDP training. 
Some models contain redundant parameters that do not participate in gradient calculation, and this setting needs to be enabled to avoid errors in multi-GPU training.\n        * `--weight_decay`: Weight decay coefficient, see [torch.optim.AdamW](https://docs.pytorch.org/docs/stable/generated/torch.optim.AdamW.html).\n        * `--task`: Training task, default is `sft`. Some models support more training modes, please refer to the documentation of each specific model.\n    * Output Configuration\n        * `--output_path`: Model saving path.\n        * `--remove_prefix_in_ckpt`: Prefix to remove from the keys of the state dict in the saved model file.\n        * `--save_steps`: Interval of training steps to save the model. If this parameter is left blank, the model is saved once per epoch.\n    * LoRA Configuration\n        * `--lora_base_model`: Which model to add LoRA to.\n        * `--lora_target_modules`: Which layers to add LoRA to.\n        * `--lora_rank`: Rank of LoRA.\n        * `--lora_checkpoint`: Path of the LoRA checkpoint. If this path is provided, LoRA will be loaded from this checkpoint.\n        * `--preset_lora_path`: Preset LoRA checkpoint path. If this path is provided, this LoRA will be merged into the base model when it is loaded. This parameter is used for differential LoRA training.\n        * `--preset_lora_model`: Model that the preset LoRA is merged into, e.g., `dit`.\n    * Gradient Configuration\n        * `--use_gradient_checkpointing`: Whether to enable gradient checkpointing.\n        * `--use_gradient_checkpointing_offload`: Whether to offload the activations saved by gradient checkpointing to CPU memory.\n        * `--gradient_accumulation_steps`: Number of gradient accumulation steps.\n    * Image Width/Height Configuration (Applicable to Image Generation and Video Generation Models)\n        * `--height`: Height of image or video. Leave `height` and `width` blank to enable dynamic resolution.\n        * `--width`: Width of image or video. 
Leave `height` and `width` blank to enable dynamic resolution.\n        * `--max_pixels`: Maximum pixel area of image or video frames. When dynamic resolution is enabled, images with resolution larger than this value will be downscaled, and images with resolution smaller than this value will remain unchanged.\n* FLUX Specific Parameters\n    * `--tokenizer_1_path`: Path of the CLIP tokenizer, leave blank to automatically download from remote.\n    * `--tokenizer_2_path`: Path of the T5 tokenizer, leave blank to automatically download from remote.\n    * `--align_to_opensource_format`: Whether to align LoRA format to open-source format, only applicable to DiT's LoRA.\n\nWe have built a sample image dataset for your testing. You can download this dataset with the following command:\n\n```shell\nmodelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --local_dir ./data/diffsynth_example_dataset\n```\n\nWe have written recommended training scripts for each model, please refer to the table in the \"Model Overview\" section above. For how to write model training scripts, please refer to [Model Training](../Pipeline_Usage/Model_Training.md); for more advanced training algorithms, please refer to [Training Framework Detailed Explanation](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/en/Training/).\n"
  },
  {
    "path": "docs/en/Model_Details/FLUX2.md",
    "content": "# FLUX.2\n\nFLUX.2 is an image generation model trained and open-sourced by Black Forest Labs.\n\n## Model Lineage\n\n```mermaid\ngraph LR;\n    FLUX.2-Series-->black-forest-labs/FLUX.2-dev;\n    FLUX.2-Series-->black-forest-labs/FLUX.2-klein-4B;\n    FLUX.2-Series-->black-forest-labs/FLUX.2-klein-9B;\n```\n\n## Installation\n\nBefore using this project for model inference and training, please install DiffSynth-Studio first.\n\n```shell\ngit clone https://github.com/modelscope/DiffSynth-Studio.git\ncd DiffSynth-Studio\npip install -e .\n```\n\nFor more information about installation, please refer to [Install Dependencies](../Pipeline_Usage/Setup.md).\n\n## Quick Start\n\nRun the following code to quickly load the [black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev) model and perform inference. VRAM management is enabled, and the framework will automatically control model parameter loading based on remaining VRAM. Minimum 10GB VRAM is required to run.\n\n```python\nfrom diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig\nimport torch\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.float8_e4m3fn,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e4m3fn,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = Flux2ImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-dev\", origin_file_pattern=\"text_encoder/*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-dev\", origin_file_pattern=\"transformer/*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-dev\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    
tokenizer_config=ModelConfig(model_id=\"black-forest-labs/FLUX.2-dev\", origin_file_pattern=\"tokenizer/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\nprompt = \"High resolution. A dreamy underwater portrait of a serene young woman in a flowing blue dress. Her hair floats softly around her face, strands delicately suspended in the water. Clear, shimmering light filters through, casting gentle highlights, while tiny bubbles rise around her. Her expression is calm, her features finely detailed—creating a tranquil, ethereal scene.\"\nimage = pipe(prompt, seed=42, rand_device=\"cuda\", num_inference_steps=50)\nimage.save(\"image.jpg\")\n```\n\n## Model Overview\n\n| Model ID | Inference | Low VRAM Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training |\n| - | - | - | - | - | - | - |\n|[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference/FLUX.2-dev.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|-|-|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)|\n|[black-forest-labs/FLUX.2-klein-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference/FLUX.2-klein-4B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference_low_vram/FLUX.2-klein-4B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/full/FLUX.2-klein-4B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_trai
ning/validate_full/FLUX.2-klein-4B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/lora/FLUX.2-klein-4B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_lora/FLUX.2-klein-4B.py)|\n|[black-forest-labs/FLUX.2-klein-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference/FLUX.2-klein-9B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/full/FLUX.2-klein-9B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_full/FLUX.2-klein-9B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/lora/FLUX.2-klein-9B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_lora/FLUX.2-klein-9B.py)|\n|[black-forest-labs/FLUX.2-klein-base-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-4B)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference/FLUX.2-klein-base-4B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-4B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/full/FLUX.2-klein-base-4B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_full/FLUX.2-klein-base-4B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/lora/FLUX.2-klein-base-4B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-4
B.py)|\n|[black-forest-labs/FLUX.2-klein-base-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-9B)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference/FLUX.2-klein-base-9B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-9B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/full/FLUX.2-klein-base-9B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_full/FLUX.2-klein-base-9B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/lora/FLUX.2-klein-base-9B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-9B.py)|\n\nSpecial Training Scripts:\n\n* Differential LoRA Training: [doc](../Training/Differential_LoRA.md)\n* FP8 Precision Training: [doc](../Training/FP8_Precision.md)\n* Two-stage Split Training: [doc](../Training/Split_Training.md)\n* End-to-end Direct Distillation: [doc](../Training/Direct_Distill.md)\n\n## Model Inference\n\nModels are loaded via `Flux2ImagePipeline.from_pretrained`, see [Loading Models](../Pipeline_Usage/Model_Inference.md#loading-models).\n\nInput parameters for `Flux2ImagePipeline` inference include:\n\n* `prompt`: Prompt describing the content appearing in the image.\n* `negative_prompt`: Negative prompt describing content that should not appear in the image, default value is `\"\"`.\n* `cfg_scale`: Classifier-free guidance parameter, default value is 1. When set to a value greater than 1, CFG is enabled.\n* `height`: Image height, must be a multiple of 16.\n* `width`: Image width, must be a multiple of 16.\n* `seed`: Random seed. 
Default is `None`, meaning completely random.\n* `rand_device`: Computing device for generating the random Gaussian noise matrix, default is `\"cpu\"`. When set to `cuda`, different GPUs will produce different generation results.\n* `num_inference_steps`: Number of inference steps, default value is 30.\n* `embedded_guidance`: Embedded guidance parameter, default value is 3.5.\n* `t5_sequence_length`: Sequence length of the T5 text encoder, default is 512.\n* `tiled`: Whether to enable VAE tiling inference, default is `False`. Setting it to `True` can significantly reduce VRAM usage during VAE encoding/decoding stages, at the cost of slight numerical errors and slightly longer inference time.\n* `tile_size`: Tile size during VAE encoding/decoding stages, default is 128, only effective when `tiled=True`.\n* `tile_stride`: Tile stride during VAE encoding/decoding stages, default is 64, only effective when `tiled=True`, must be less than or equal to `tile_size`.\n* `progress_bar_cmd`: Progress bar, default is `tqdm.tqdm`. Can be disabled by setting it to `lambda x:x`.\n\nIf VRAM is insufficient, please enable [VRAM Management](../Pipeline_Usage/VRAM_management.md). 
We provide recommended low VRAM configurations for each model in the example code, see the table in the \"Model Overview\" section above.\n\n## Model Training\n\nFLUX.2 series models are uniformly trained through [`examples/flux2/model_training/train.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/train.py), and the script parameters include:\n\n* General Training Parameters\n    * Dataset Basic Configuration\n        * `--dataset_base_path`: Root directory of the dataset.\n        * `--dataset_metadata_path`: Metadata file path of the dataset.\n        * `--dataset_repeat`: Number of times the dataset is repeated in each epoch.\n        * `--dataset_num_workers`: Number of processes for each DataLoader.\n        * `--data_file_keys`: Field names to be loaded from metadata, usually image or video file paths, separated by `,`.\n    * Model Loading Configuration\n        * `--model_paths`: Paths of models to be loaded. JSON format.\n        * `--model_id_with_origin_paths`: Model IDs with original paths, e.g., `\"black-forest-labs/FLUX.2-dev:text_encoder/*.safetensors\"`. Separated by commas.\n        * `--extra_inputs`: Extra input parameters required by the model Pipeline, e.g., `controlnet_inputs` when training ControlNet models, separated by `,`.\n        * `--fp8_models`: Models loaded in FP8 format, consistent with `--model_paths` or `--model_id_with_origin_paths` format. Currently only supports models whose parameters are not updated by gradients (no gradient backpropagation, or gradients only update their LoRA).\n    * Training Basic Configuration\n        * `--learning_rate`: Learning rate.\n        * `--num_epochs`: Number of epochs.\n        * `--trainable_models`: Trainable models, e.g., `dit`, `vae`, `text_encoder`.\n        * `--find_unused_parameters`: Whether there are unused parameters in DDP training. 
Some models contain redundant parameters that do not participate in gradient calculation, and this setting needs to be enabled to avoid errors in multi-GPU training.\n        * `--weight_decay`: Weight decay coefficient, see [torch.optim.AdamW](https://docs.pytorch.org/docs/stable/generated/torch.optim.AdamW.html).\n        * `--task`: Training task, default is `sft`. Some models support more training modes, please refer to the documentation of each specific model.\n    * Output Configuration\n        * `--output_path`: Model saving path.\n        * `--remove_prefix_in_ckpt`: Prefix to remove from the keys of the state dict in the saved model file.\n        * `--save_steps`: Interval of training steps to save the model. If this parameter is left blank, the model is saved once per epoch.\n    * LoRA Configuration\n        * `--lora_base_model`: Which model to add LoRA to.\n        * `--lora_target_modules`: Which layers to add LoRA to.\n        * `--lora_rank`: Rank of LoRA.\n        * `--lora_checkpoint`: Path of the LoRA checkpoint. If this path is provided, LoRA will be loaded from this checkpoint.\n        * `--preset_lora_path`: Preset LoRA checkpoint path. If this path is provided, this LoRA will be merged into the base model when it is loaded. This parameter is used for differential LoRA training.\n        * `--preset_lora_model`: Model that the preset LoRA is merged into, e.g., `dit`.\n    * Gradient Configuration\n        * `--use_gradient_checkpointing`: Whether to enable gradient checkpointing.\n        * `--use_gradient_checkpointing_offload`: Whether to offload the activations saved by gradient checkpointing to CPU memory.\n        * `--gradient_accumulation_steps`: Number of gradient accumulation steps.\n    * Image Width/Height Configuration (Applicable to Image Generation and Video Generation Models)\n        * `--height`: Height of image or video. Leave `height` and `width` blank to enable dynamic resolution.\n        * `--width`: Width of image or video. 
Leave `height` and `width` blank to enable dynamic resolution.\n        * `--max_pixels`: Maximum pixel area of image or video frames. When dynamic resolution is enabled, images with resolution larger than this value will be downscaled, and images with resolution smaller than this value will remain unchanged.\n* FLUX.2 Specific Parameters\n    * `--tokenizer_path`: Path of the tokenizer, applicable to text-to-image models, leave blank to automatically download from remote.\n\nWe have built a sample image dataset for your testing. You can download this dataset with the following command:\n\n```shell\nmodelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --local_dir ./data/diffsynth_example_dataset\n```\n\nWe have written recommended training scripts for each model, please refer to the table in the \"Model Overview\" section above. For how to write model training scripts, please refer to [Model Training](../Pipeline_Usage/Model_Training.md); for more advanced training algorithms, please refer to [Training Framework Detailed Explanation](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/en/Training/).\n"
  },
  {
    "path": "docs/en/Model_Details/LTX-2.md",
    "content": "# LTX-2\n\nLTX-2 is a series of audio-video generation models developed by Lightricks.\n\n## Installation\n\nBefore using this project for model inference and training, please install DiffSynth-Studio first.\n\n```shell\ngit clone https://github.com/modelscope/DiffSynth-Studio.git\ncd DiffSynth-Studio\npip install -e .\n```\n\nFor more information about installation, please refer to [Installation Dependencies](../Pipeline_Usage/Setup.md).\n\n## Quick Start\n\nRun the following code to quickly load the [Lightricks/LTX-2](https://www.modelscope.cn/models/Lightricks/LTX-2) model and perform inference. VRAM management has been enabled, and the framework will automatically control model parameter loading based on remaining VRAM. It can run with a minimum of 8GB VRAM.\n\n```python\nimport torch\nfrom diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig\nfrom diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2\n\nvram_config = {\n    \"offload_dtype\": torch.float8_e5m2,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.float8_e5m2,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e5m2,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\n\"\"\"\nOfficial model repo: https://www.modelscope.cn/models/Lightricks/LTX-2\nRepackaged model repo: https://www.modelscope.cn/models/DiffSynth-Studio/LTX-2-Repackage\nFor base models of LTX-2, the official checkpoint (with model config ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-19b-dev.safetensors\"))\nand repackaged checkpoints (with model config ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"*.safetensors\")) are both supported.\nWe have repackaged the official checkpoints in DiffSynth-Studio/LTX-2-Repackage repo to support separate loading of different submodules,\nand avoid redundant memory usage when users only 
want to use part of the model.\n\"\"\"\n# use the repackaged modelconfig from \"DiffSynth-Studio/LTX-2-Repackage\" to avoid redundant model loading\npipe = LTX2AudioVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\", origin_file_pattern=\"model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"transformer.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"text_encoder_post_modules.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"video_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"audio_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"audio_vocoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"video_vae_encoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-spatial-upscaler-x2-1.0.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\"),\n    stage2_lora_config=ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-19b-distilled-lora-384.safetensors\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\n\n# use the following modelconfig if you want to initialize model from offical checkpoints from \"Lightricks/LTX-2\"\n# pipe = LTX2AudioVideoPipeline.from_pretrained(\n#     torch_dtype=torch.bfloat16,\n#     device=\"cuda\",\n#     model_configs=[\n#         
ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\", origin_file_pattern=\"model-*.safetensors\", **vram_config),\n#         ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-19b-dev.safetensors\", **vram_config),\n#         ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-spatial-upscaler-x2-1.0.safetensors\", **vram_config),\n#     ],\n#     tokenizer_config=ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\"),\n#     stage2_lora_config=ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-19b-distilled-lora-384.safetensors\"),\n#     vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n# )\n\nprompt = \"A girl is very happy, she is speaking: \\\"I enjoy working with Diffsynth-Studio, it's a perfect framework.\\\"\"\nnegative_prompt = (\n    \"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, \"\n    \"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, \"\n    \"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, \"\n    \"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of \"\n    \"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent \"\n    \"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny \"\n    \"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, \"\n    \"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, \"\n    \"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward \"\n    \"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, \"\n    \"inconsistent tone, cinematic 
oversaturation, stylized filters, or AI artifacts.\"\n)\nheight, width, num_frames = 512 * 2, 768 * 2, 121\nvideo, audio = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    seed=43,\n    height=height,\n    width=width,\n    num_frames=num_frames,\n    tiled=True,\n    use_two_stage_pipeline=True,\n)\nwrite_video_audio_ltx2(\n    video=video,\n    audio=audio,\n    output_path='ltx2_twostage.mp4',\n    fps=24,\n    audio_sample_rate=24000,\n)\n```\n\n## Model Overview\n|Model ID|Additional Parameters|Inference|Low VRAM Inference|Full Training|Validation After Full Training|LoRA Training|Validation After LoRA Training|\n|-|-|-|-|-|-|-|-|\n|[Lightricks/LTX-2.3: OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`input_images`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2.3-I2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2.3-I2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/full/LTX-2.3-I2AV-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_full/LTX-2.3-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/lora/LTX-2.3-I2AV-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_lora/LTX-2.3-I2AV.py)|\n|[Lightricks/LTX-2.3: TwoStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`input_images`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2.3-I2AV-TwoStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2.3-I2AV-TwoStage.py)|-|-|-|-|\n|[Lightricks/LTX-2.3: 
DistilledPipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`input_images`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2.3-I2AV-DistilledPipeline.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2.3-I2AV-DistilledPipeline.py)|-|-|-|-|\n|[Lightricks/LTX-2.3: OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2.3-T2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/full/LTX-2.3-T2AV-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_full/LTX-2.3-T2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/lora/LTX-2.3-T2AV-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_lora/LTX-2.3-T2AV.py)|\n|[Lightricks/LTX-2.3: TwoStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2.3-T2AV-TwoStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-TwoStage.py)|-|-|-|-|\n|[Lightricks/LTX-2.3: DistilledPipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2.3-T2AV-DistilledPipeline.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-DistilledPipeline.py)|-|-|-|-|\n|[Lightricks/LTX-2.3: 
A2V](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`retake_audio`,`audio_sample_rate`,`retake_audio_regions`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2.3-A2V-TwoStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2.3-A2V-TwoStage.py)|-|-|-|-|\n|[Lightricks/LTX-2.3: Retake](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`retake_video`,`retake_video_regions`,`retake_audio`,`audio_sample_rate`,`retake_audio_regions`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2.3-T2AV-TwoStage-Retake.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-TwoStage-Retake.py)|-|-|-|-|\n|[Lightricks/LTX-2.3-22b-IC-LoRA-Union-Control](https://www.modelscope.cn/models/Lightricks/LTX-2.3-22b-IC-LoRA-Union-Control)|`in_context_videos`,`in_context_downsample_factor`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2.3-T2AV-IC-LoRA-Union-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-IC-LoRA-Union-Control.py)|-|-|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/lora/LTX-2.3-T2AV-IC-LoRA-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_lora/LTX-2.3-T2AV-IC-LoRA.py)|\n|[Lightricks/LTX-2.3-22b-IC-LoRA-Motion-Track-Control](https://www.modelscope.cn/models/Lightricks/LTX-2.3-22b-IC-LoRA-Motion-Track-Control)|`in_context_videos`,`in_context_downsample_factor`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2.3-T2AV-IC-LoRA-Motion-Track-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-IC-Lo
RA-Motion-Track-Control.py)|-|-|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/lora/LTX-2.3-T2AV-IC-LoRA-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_lora/LTX-2.3-T2AV-IC-LoRA.py)|\n|[Lightricks/LTX-2: OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/full/LTX-2-T2AV-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_full/LTX-2-T2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/lora/LTX-2-T2AV-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_lora/LTX-2-T2AV.py)|\n|[Lightricks/LTX-2-19b-IC-LoRA-Union-Control](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-IC-LoRA-Union-Control)|`in_context_videos`,`in_context_downsample_factor`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-IC-LoRA-Union-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-IC-LoRA-Union-Control.py)|-|-|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/lora/LTX-2-T2AV-IC-LoRA-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_lora/LTX-2-T2AV-IC-LoRA.py)|\n|[Lightricks/LTX-2-19b-IC-LoRA-Detailer](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-IC-LoRA-Detailer)|`in_context_videos`,`in_context_downsample_factor`|[
code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-IC-LoRA-Detailer.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-IC-LoRA-Detailer.py)|-|-|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/lora/LTX-2-T2AV-IC-LoRA-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_lora/LTX-2-T2AV-IC-LoRA.py)|\n|[Lightricks/LTX-2: TwoStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-TwoStage.py)|-|-|-|-|\n|[Lightricks/LTX-2: DistilledPipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-DistilledPipeline.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-DistilledPipeline.py)|-|-|-|-|\n|[Lightricks/LTX-2: OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-I2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-OneStage.py)|-|-|-|-|\n|[Lightricks/LTX-2: TwoStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-I2AV-TwoStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-TwoStage.py)|-|-|-|-|\n|[Lightricks/LTX-2: 
DistilledPipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-I2AV-DistilledPipeline.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-DistilledPipeline.py)|-|-|-|-|\n|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-In](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-In)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-In.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-In.py)|-|-|-|-|\n|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Out](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Out)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Out.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Out.py)|-|-|-|-|\n|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Left](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Left)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Left.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Left.py)|-|-|-|-|\n|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Right](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Right)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Right.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx
2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Right.py)|-|-|-|-|\n|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Up](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Up)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Jib-Up.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Jib-Up.py)|-|-|-|-|\n|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Down](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Down)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Jib-Down.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Jib-Down.py)|-|-|-|-|\n|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Static](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Static)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Static.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Static.py)|-|-|-|-|\n\n## Model Inference\n\nModels are loaded through `LTX2AudioVideoPipeline.from_pretrained`, see [Loading Models](../Pipeline_Usage/Model_Inference.md#loading-models) for details.\n\nInput parameters for `LTX2AudioVideoPipeline` inference include:\n\n* `prompt`: Prompt describing the content appearing in the video.\n* `negative_prompt`: Negative prompt describing content that should not appear in the video, default value is `\"\"`.\n* `cfg_scale`: Classifier-free guidance parameter, default value is 3.0.\n* `input_images`: List of input images for image-to-video generation.\n* `input_images_indexes`: Frame index list of input images in the video.\n* 
`input_images_strength`: Strength of input images, default value is 1.0.\n* `denoising_strength`: Denoising strength, ranges from 0 to 1, default value is 1.0.\n* `seed`: Random seed. Default is `None`, which means completely random.\n* `rand_device`: Device used to generate the random Gaussian noise, default is `\"cpu\"`. When set to `cuda`, the same seed may produce different results on different GPUs.\n* `height`: Video height, must be a multiple of 32 (single-stage) or 64 (two-stage).\n* `width`: Video width, must be a multiple of 32 (single-stage) or 64 (two-stage).\n* `num_frames`: Number of video frames, default value is 121, must be of the form 8k + 1 (e.g., 121).\n* `num_inference_steps`: Number of inference steps, default value is 40.\n* `tiled`: Whether to enable VAE tiling inference, default is `True`. When set to `True`, it can significantly reduce VRAM usage during VAE encoding/decoding stages, at the cost of slight numerical errors and slightly longer inference time.\n* `tile_size_in_pixels`: Pixel tiling size during VAE encoding/decoding stages, default is 512.\n* `tile_overlap_in_pixels`: Pixel tiling overlap size during VAE encoding/decoding stages, default is 128.\n* `tile_size_in_frames`: Frame tiling size during VAE encoding/decoding stages, default is 128.\n* `tile_overlap_in_frames`: Frame tiling overlap size during VAE encoding/decoding stages, default is 24.\n* `use_two_stage_pipeline`: Whether to use the two-stage pipeline, default is `False`.\n* `use_distilled_pipeline`: Whether to use the distilled pipeline, default is `False`.\n* `progress_bar_cmd`: Progress bar, default is `tqdm.tqdm`. Can be set to `lambda x:x` to hide the progress bar.\n\nIf VRAM is insufficient, please enable [VRAM Management](../Pipeline_Usage/VRAM_management.md). 
We provide recommended low VRAM configurations for each model in the example code, see the table in the \"Model Overview\" section above.\n\n## Model Training\n\nLTX-2 series models are uniformly trained through [`examples/ltx2/model_training/train.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/train.py), and the script parameters include:\n\n* General Training Parameters\n    * Dataset Basic Configuration\n        * `--dataset_base_path`: Root directory of the dataset.\n        * `--dataset_metadata_path`: Metadata file path of the dataset.\n        * `--dataset_repeat`: Number of times the dataset is repeated in each epoch.\n        * `--dataset_num_workers`: Number of processes for each DataLoader.\n        * `--data_file_keys`: Field names to be loaded from metadata, usually image or video file paths, separated by `,`.\n    * Model Loading Configuration\n        * `--model_paths`: Paths of models to be loaded. JSON format.\n        * `--model_id_with_origin_paths`: Model IDs with original paths, e.g., `\"DiffSynth-Studio/LTX-2-Repackage:transformer.safetensors\"`. Separated by commas.\n        * `--extra_inputs`: Extra input parameters required by the model Pipeline, e.g., extra parameters when training image editing models, separated by `,`.\n        * `--fp8_models`: Models loaded in FP8 format, consistent with `--model_paths` or `--model_id_with_origin_paths` format. Currently only supports models whose parameters are not updated by gradients (no gradient backpropagation, or gradients only update their LoRA).\n    * Training Basic Configuration\n        * `--learning_rate`: Learning rate.\n        * `--num_epochs`: Number of epochs.\n        * `--trainable_models`: Trainable models, e.g., `dit`, `vae`, `text_encoder`.\n        * `--find_unused_parameters`: Whether there are unused parameters in DDP training. 
Some models contain redundant parameters that do not participate in gradient calculation, and this setting needs to be enabled to avoid errors in multi-GPU training.\n        * `--weight_decay`: Weight decay size, see [torch.optim.AdamW](https://docs.pytorch.org/docs/stable/generated/torch.optim.AdamW.html).\n        * `--task`: Training task, default is `sft`. Some models support more training modes, please refer to the documentation of each specific model.\n    * Output Configuration\n        * `--output_path`: Model saving path.\n        * `--remove_prefix_in_ckpt`: Remove prefix in the state dict of the model file.\n        * `--save_steps`: Interval of training steps to save the model. If this parameter is left blank, the model is saved once per epoch.\n    * LoRA Configuration\n        * `--lora_base_model`: Which model to add LoRA to.\n        * `--lora_target_modules`: Which layers to add LoRA to.\n        * `--lora_rank`: Rank of LoRA.\n        * `--lora_checkpoint`: Path of the LoRA checkpoint. If this path is provided, LoRA will be loaded from this checkpoint.\n        * `--preset_lora_path`: Preset LoRA checkpoint path. If this path is provided, this LoRA will be loaded in the form of being merged into the base model. This parameter is used for LoRA differential training.\n        * `--preset_lora_model`: Model that the preset LoRA is merged into, e.g., `dit`.\n    * Gradient Configuration\n        * `--use_gradient_checkpointing`: Whether to enable gradient checkpointing.\n        * `--use_gradient_checkpointing_offload`: Whether to offload gradient checkpointing to memory.\n        * `--gradient_accumulation_steps`: Number of gradient accumulation steps.\n    * Video Width/Height Configuration\n        * `--height`: Height of the video. Leave `height` and `width` blank to enable dynamic resolution.\n        * `--width`: Width of the video. 
Leave `height` and `width` blank to enable dynamic resolution.\n        * `--max_pixels`: Maximum pixel area of video frames. When dynamic resolution is enabled, video frames with resolution larger than this value will be downscaled, and video frames with resolution smaller than this value will remain unchanged.\n        * `--num_frames`: Number of frames in the video.\n* LTX-2 Series Specific Parameters\n    * `--tokenizer_path`: Path of the tokenizer, applicable to text-to-video models. Leave blank to download it automatically from the remote repository.\n    * `--frame_rate`: Frame rate of the training videos.\n\nWe have built a sample video dataset for your testing. You can download this dataset with the following command:\n\n```shell\nmodelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --local_dir ./data/diffsynth_example_dataset\n```\n\nWe have written recommended training scripts for each model; please refer to the table in the \"Model Overview\" section above. For how to write model training scripts, please refer to [Model Training](../Pipeline_Usage/Model_Training.md); for more advanced training algorithms, please refer to [Training Framework Detailed Explanation](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/en/Training/).\n"
  },
  {
    "path": "docs/en/Model_Details/Overview.md",
    "content": "# Model Directory\n\n## Qwen-Image\n\nDocumentation: [./Qwen-Image.md](../Model_Details/Qwen-Image.md)\n\n<details>\n\n<summary>Effect Preview</summary>\n\n![Image](https://github.com/user-attachments/assets/738078d8-8749-4a53-a046-571861541924)\n\n</details>\n\n<details>\n\n<summary>Quick Start</summary>\n\n```python\nfrom diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nfrom PIL import Image\nimport torch\n\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n)\nprompt = \"精致肖像，水下少女，蓝裙飘逸，发丝轻扬，光影透澈，气泡环绕，面容恬静，细节精致，梦幻唯美。\"\nimage = pipe(\n    prompt, seed=0, num_inference_steps=40,\n    # edit_image=Image.open(\"xxx.jpg\").resize((1328, 1328)) # For Qwen-Image-Edit\n)\nimage.save(\"image.jpg\")\n```\n\n</details>\n\n<details>\n\n<summary>Model Lineage</summary>\n\n```mermaid\ngraph LR;\n    Qwen/Qwen-Image-->Qwen/Qwen-Image-Edit;\n    Qwen/Qwen-Image-Edit-->Qwen/Qwen-Image-Edit-2509;\n    Qwen/Qwen-Image-->EliGen-Series;\n    EliGen-Series-->DiffSynth-Studio/Qwen-Image-EliGen;\n    DiffSynth-Studio/Qwen-Image-EliGen-->DiffSynth-Studio/Qwen-Image-EliGen-V2;\n    EliGen-Series-->DiffSynth-Studio/Qwen-Image-EliGen-Poster;\n    Qwen/Qwen-Image-->Distill-Series;\n    Distill-Series-->DiffSynth-Studio/Qwen-Image-Distill-Full;\n    Distill-Series-->DiffSynth-Studio/Qwen-Image-Distill-LoRA;\n    Qwen/Qwen-Image-->ControlNet-Series;\n    ControlNet-Series-->Blockwise-ControlNet-Series;\n    
Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny;\n    Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth;\n    Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint;\n    ControlNet-Series-->DiffSynth-Studio/Qwen-Image-In-Context-Control-Union;\n    Qwen/Qwen-Image-->DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix;\n```\n\n</details>\n\n| Model ID | Inference | Low VRAM Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training |\n| - | - | - | - | - | - | - |\n| [Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image.py) |\n| [Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Edit.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh) | 
[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py) |\n| [Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py) |\n| [DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-EliGen.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py) | - | - | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py) |\n| 
[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py) | - | - | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py) |\n| [DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py) | - | - | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py) |\n| [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-Full.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py) 
| [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py) |\n| [DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Distill-LoRA.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-LoRA.py) | - | - | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-LoRA.py) |\n| [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Canny.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Canny.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Canny.sh) | 
[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Canny.py) |\n| [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Depth.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Depth.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Depth.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Depth.py) |\n| [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | 
[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py) |\n| [DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py) | - | - | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py) |\n| [DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py) | - | - | - | - |\n\n## FLUX Series\n\nDocumentation: [./FLUX.md](../Model_Details/FLUX.md)\n\n<details>\n\n<summary>Effect Preview</summary>\n\n![Image](https://github.com/user-attachments/assets/c01258e2-f251-441a-aa1e-ebb22f02594d)\n\n</details>\n\n<details>\n\n<summary>Quick Start</summary>\n\n```python\nimport torch\nfrom diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig\n\npipe = FluxImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        
ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"flux1-dev.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder/model.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder_2/*.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"ae.safetensors\"),\n    ],\n)\n\nimage = pipe(prompt=\"a cat\", seed=0)\nimage.save(\"image.jpg\")\n```\n\n</details>\n\n<details>\n\n<summary>Model Lineage</summary>\n\n```mermaid\ngraph LR;\n    FLUX.1-Series-->black-forest-labs/FLUX.1-dev;\n    FLUX.1-Series-->black-forest-labs/FLUX.1-Krea-dev;\n    FLUX.1-Series-->black-forest-labs/FLUX.1-Kontext-dev;\n    black-forest-labs/FLUX.1-dev-->FLUX.1-dev-ControlNet-Series;\n    FLUX.1-dev-ControlNet-Series-->alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta;\n    FLUX.1-dev-ControlNet-Series-->InstantX/FLUX.1-dev-Controlnet-Union-alpha;\n    FLUX.1-dev-ControlNet-Series-->jasperai/Flux.1-dev-Controlnet-Upscaler;\n    black-forest-labs/FLUX.1-dev-->InstantX/FLUX.1-dev-IP-Adapter;\n    black-forest-labs/FLUX.1-dev-->ByteDance/InfiniteYou;\n    black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/Eligen;\n    black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev;\n    black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev;\n    black-forest-labs/FLUX.1-dev-->ostris/Flex.2-preview;\n    black-forest-labs/FLUX.1-dev-->stepfun-ai/Step1X-Edit;\n    Qwen/Qwen2.5-VL-7B-Instruct-->stepfun-ai/Step1X-Edit;\n    black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/Nexus-GenV2;\n    Qwen/Qwen2.5-VL-7B-Instruct-->DiffSynth-Studio/Nexus-GenV2;\n```\n\n</details>\n\n| Model ID | Extra Parameters | Inference | Low VRAM Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training |\n| - | - | - | - | - | - | - | - |\n| 
[black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev) | | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev.py) |\n| [black-forest-labs/FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev) | | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-Krea-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-Krea-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-Krea-dev.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-Krea-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-Krea-dev.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-Krea-dev.py) |\n| [black-forest-labs/FLUX.1-Kontext-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev) | `kontext_images` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-Kontext-dev.py) | 
[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-Kontext-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-Kontext-dev.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-Kontext-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-Kontext-dev.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-Kontext-dev.py) |\n| [alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta](https://www.modelscope.cn/models/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta) | `controlnet_inputs` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-Controlnet-Inpainting-Beta.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Inpainting-Beta.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Inpainting-Beta.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Inpainting-Beta.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Inpainting-Beta.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Inpainting-Beta.py) |\n| [InstantX/FLUX.1-dev-Controlnet-Union-alpha](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-Controlnet-Union-alpha) | `controlnet_inputs` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-Controlnet-Union-alpha.py) | 
[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Union-alpha.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Union-alpha.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Union-alpha.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Union-alpha.py) |\n| [jasperai/Flux.1-dev-Controlnet-Upscaler](https://www.modelscope.cn/models/jasperai/Flux.1-dev-Controlnet-Upscaler) | `controlnet_inputs` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-Controlnet-Upscaler.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Upscaler.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Upscaler.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Upscaler.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Upscaler.py) |\n| [InstantX/FLUX.1-dev-IP-Adapter](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-IP-Adapter) | `ipadapter_images`, `ipadapter_scale` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-IP-Adapter.py) | 
[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-IP-Adapter.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-IP-Adapter.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-IP-Adapter.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-IP-Adapter.py) |\n| [ByteDance/InfiniteYou](https://www.modelscope.cn/models/ByteDance/InfiniteYou) | `infinityou_id_image`, `infinityou_guidance`, `controlnet_inputs` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-InfiniteYou.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-InfiniteYou.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-InfiniteYou.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-InfiniteYou.py) |\n| [DiffSynth-Studio/Eligen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen) | `eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-EliGen.py) | 
[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py) | - | - | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-EliGen.py) |\n| [DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev](https://www.modelscope.cn/models/DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev) | `lora_encoder_inputs`, `lora_encoder_scale` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-LoRA-Encoder.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py) | - | - |\n| [DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev) | | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py) | - | - | - | - | - |\n| [stepfun-ai/Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit) | `step1x_reference_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/Step1X-Edit.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/Step1X-Edit.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/Step1X-Edit.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/Step1X-Edit.py) | 
[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/Step1X-Edit.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/Step1X-Edit.py) |\n| [ostris/Flex.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview) | `flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, `flex_control_stop` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLEX.2-preview.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLEX.2-preview.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLEX.2-preview.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLEX.2-preview.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLEX.2-preview.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLEX.2-preview.py) |\n| [DiffSynth-Studio/Nexus-GenV2](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2) | `nexus_gen_reference_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/Nexus-Gen-Editing.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/Nexus-Gen.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/Nexus-Gen.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/Nexus-Gen.sh) | 
[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/Nexus-Gen.py) |\n\n## Wan Series\n\nDocumentation: [./Wan.md](../Model_Details/Wan.md)\n\n<details>\n\n<summary>Effect Preview</summary>\n\nhttps://github.com/user-attachments/assets/1d66ae74-3b02-40a9-acc3-ea95fc039314\n\n</details>\n\n<details>\n\n<summary>Quick Start</summary>\n\n```python\nimport torch\nfrom diffsynth.utils.data import save_video\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n    ],\n)\n\nvideo = pipe(\n    prompt=\"纪实摄影风格画面，一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄，两只耳朵立起，神情专注而欢快。阳光洒在它身上，使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地，偶尔点缀着几朵野花，远处隐约可见蓝天和几片白云。透视感鲜明，捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    seed=0, tiled=True,\n)\nsave_video(video, \"video.mp4\", fps=15, quality=5)\n```\n\n</details>\n\n<details>\n\n<summary>Model Lineage</summary>\n\n```mermaid\ngraph LR;\n    Wan-Series-->Wan2.1-Series;\n    Wan-Series-->Wan2.2-Series;\n    Wan2.1-Series-->Wan-AI/Wan2.1-T2V-1.3B;\n    Wan2.1-Series-->Wan-AI/Wan2.1-T2V-14B;\n    Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-I2V-14B-480P;\n    Wan-AI/Wan2.1-I2V-14B-480P-->Wan-AI/Wan2.1-I2V-14B-720P;\n    Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-FLF2V-14B-720P;\n    Wan-AI/Wan2.1-T2V-1.3B-->iic/VACE-Wan2.1-1.3B-Preview;\n    iic/VACE-Wan2.1-1.3B-Preview-->Wan-AI/Wan2.1-VACE-1.3B;\n    
Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-VACE-14B;\n    Wan-AI/Wan2.1-T2V-1.3B-->Wan2.1-Fun-1.3B-Series;\n    Wan2.1-Fun-1.3B-Series-->PAI/Wan2.1-Fun-1.3B-InP;\n    Wan2.1-Fun-1.3B-Series-->PAI/Wan2.1-Fun-1.3B-Control;\n    Wan-AI/Wan2.1-T2V-14B-->Wan2.1-Fun-14B-Series;\n    Wan2.1-Fun-14B-Series-->PAI/Wan2.1-Fun-14B-InP;\n    Wan2.1-Fun-14B-Series-->PAI/Wan2.1-Fun-14B-Control;\n    Wan-AI/Wan2.1-T2V-1.3B-->Wan2.1-Fun-V1.1-1.3B-Series;\n    Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-Control;\n    Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-InP;\n    Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera;\n    Wan-AI/Wan2.1-T2V-14B-->Wan2.1-Fun-V1.1-14B-Series;\n    Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-Control;\n    Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-InP;\n    Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-Control-Camera;\n    Wan-AI/Wan2.1-T2V-1.3B-->DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1;\n    Wan-AI/Wan2.1-T2V-14B-->krea/krea-realtime-video;\n    Wan-AI/Wan2.1-I2V-14B-720P-->ByteDance/Video-As-Prompt-Wan2.1-14B;\n    Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.2-Animate-14B;\n    Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.2-S2V-14B;\n    Wan2.2-Series-->Wan-AI/Wan2.2-T2V-A14B;\n    Wan2.2-Series-->Wan-AI/Wan2.2-I2V-A14B;\n    Wan2.2-Series-->Wan-AI/Wan2.2-TI2V-5B;\n    Wan-AI/Wan2.2-T2V-A14B-->Wan2.2-Fun-Series;\n    Wan2.2-Fun-Series-->PAI/Wan2.2-VACE-Fun-A14B;\n    Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-InP;\n    Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-Control;\n    Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-Control-Camera;\n```\n\n</details>\n\n| Model ID | Extra Parameters | Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training |\n| - | - | - | - | - | - | - |\n| [Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B) | | 
[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-1.3B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-1.3B.py) |\n| [Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B) | | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-T2V-14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py) |\n| [Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P) | `input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh) | 
[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py) |\n| [Wan-AI/Wan2.1-I2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P) | `input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-I2V-14B-720P.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-720P.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-720P.py) |\n| [Wan-AI/Wan2.1-FLF2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P) | `input_image`, `end_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-FLF2V-14B-720P.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-FLF2V-14B-720P.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-FLF2V-14B-720P.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py) |\n| [iic/VACE-Wan2.1-1.3B-Preview](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview) | `vace_control_video`, `vace_reference_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B-Preview.py) | 
[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B-Preview.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B-Preview.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B-Preview.py) |\n| [Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B) | `vace_control_video`, `vace_reference_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B.py) |\n| [Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B) | `vace_control_video`, `vace_reference_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-VACE-14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh) | 
[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py) |\n| [PAI/Wan2.1-Fun-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP) | `input_image`, `end_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py) |\n| [PAI/Wan2.1-Fun-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control) | `control_video` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py) |\n| [PAI/Wan2.1-Fun-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP) | `input_image`, `end_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-14B-InP.py) | 
[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-InP.py) |\n| [PAI/Wan2.1-Fun-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control) | `control_video` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-14B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-Control.py) |\n| [PAI/Wan2.1-Fun-V1.1-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control) | `control_video`, `reference_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh) | 
[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py) |\n| [PAI/Wan2.1-Fun-V1.1-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control) | `control_video`, `reference_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py) |\n| [PAI/Wan2.1-Fun-V1.1-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-InP) | `input_image`, `end_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-InP.py) |\n| [PAI/Wan2.1-Fun-V1.1-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-InP) | `input_image`, `end_image` | 
[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-InP.py) |\n| [PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera) | `control_camera_video`, `input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py) |\n| [PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera) | `control_camera_video`, `input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control-Camera.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control-Camera.sh) 
| [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control-Camera.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control-Camera.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control-Camera.py) |\n| [DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1) | `motion_bucket_id` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py) |\n| [krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video) | | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/krea-realtime-video.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/krea-realtime-video.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/krea-realtime-video.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/krea-realtime-video.sh) | 
[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/krea-realtime-video.py) |\n| [meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video) | `longcat_video` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/LongCat-Video.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/LongCat-Video.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/LongCat-Video.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/LongCat-Video.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/LongCat-Video.py) |\n| [ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B) | `vap_video`, `vap_prompt` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Video-As-Prompt-Wan2.1-14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Video-As-Prompt-Wan2.1-14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Video-As-Prompt-Wan2.1-14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Video-As-Prompt-Wan2.1-14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Video-As-Prompt-Wan2.1-14B.py) |\n| [Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B) | | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py) | 
[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py) |\n| [Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B) | `input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py) |\n| [Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B) | `input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py) |\n| 
[Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B) | `input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Animate-14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py) |\n| [Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B) | `input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-S2V-14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-S2V-14B.py) |\n| [PAI/Wan2.2-VACE-Fun-A14B](https://www.modelscope.cn/models/PAI/Wan2.2-VACE-Fun-A14B) | `vace_control_video`, `vace_reference_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-VACE-Fun-A14B.py) | 
[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-VACE-Fun-A14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-VACE-Fun-A14B.py) |\n| [PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP) | `input_image`, `end_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py) |\n| [PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control) | `control_video`, `reference_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh) | 
[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py) |\n| [PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera) | `control_camera_video`, `input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py) |\n\n* FP8 Precision Training: [doc](../Training/FP8_Precision.md), [code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo/model_training/special/fp8_training/)\n* Two-stage Split Training: [doc](../Training/Split_Training.md), [code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo/model_training/special/split_training/)\n* End-to-end Direct Distillation: [doc](../Training/Direct_Distill.md), [code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo/model_training/special/direct_distill/)\n"
  },
  {
    "path": "docs/en/Model_Details/Qwen-Image.md",
    "content": "# Qwen-Image\n\n![Image](https://github.com/user-attachments/assets/738078d8-8749-4a53-a046-571861541924)\n\nQwen-Image is an image generation model trained and open-sourced by the Tongyi Lab Qwen Team of Alibaba.\n\n## Installation\n\nBefore using this project for model inference and training, please install DiffSynth-Studio first.\n\n```shell\ngit clone https://github.com/modelscope/DiffSynth-Studio.git\ncd DiffSynth-Studio\npip install -e .\n```\n\nFor more information about installation, please refer to [Install Dependencies](../Pipeline_Usage/Setup.md).\n\n## Quick Start\n\nRun the following code to quickly load the [Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image) model and perform inference. VRAM management is enabled, and the framework will automatically control model parameter loading based on remaining VRAM. Minimum 8GB VRAM is required to run.\n\n```python\nfrom diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nimport torch\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.float8_e4m3fn,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e4m3fn,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n    
vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\nprompt = \"精致肖像，水下少女，蓝裙飘逸，发丝轻扬，光影透澈，气泡环绕，面容恬静，细节精致，梦幻唯美。\"\nimage = pipe(prompt, seed=0, num_inference_steps=40)\nimage.save(\"image.jpg\")\n```\n\n## Model Overview\n\n<details>\n\n<summary>Model Lineage</summary>\n\n```mermaid\ngraph LR;\n    Qwen/Qwen-Image-->Qwen/Qwen-Image-Edit;\n    Qwen/Qwen-Image-Edit-->Qwen/Qwen-Image-Edit-2509;\n    Qwen/Qwen-Image-->EliGen-Series;\n    EliGen-Series-->DiffSynth-Studio/Qwen-Image-EliGen;\n    DiffSynth-Studio/Qwen-Image-EliGen-->DiffSynth-Studio/Qwen-Image-EliGen-V2;\n    EliGen-Series-->DiffSynth-Studio/Qwen-Image-EliGen-Poster;\n    Qwen/Qwen-Image-->Distill-Series;\n    Distill-Series-->DiffSynth-Studio/Qwen-Image-Distill-Full;\n    Distill-Series-->DiffSynth-Studio/Qwen-Image-Distill-LoRA;\n    Qwen/Qwen-Image-->ControlNet-Series;\n    ControlNet-Series-->Blockwise-ControlNet-Series;\n    Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny;\n    Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth;\n    Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint;\n    ControlNet-Series-->DiffSynth-Studio/Qwen-Image-In-Context-Control-Union;\n    Qwen/Qwen-Image-->DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix;\n```\n\n</details>\n\n| Model ID | Inference | Low VRAM Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training |\n| - | - | - | - | - | - | - |\n| [Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image.sh) | 
[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image.py) |\n|[Qwen/Qwen-Image-2512](https://www.modelscope.cn/models/Qwen/Qwen-Image-2512)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-2512.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-2512.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-2512.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-2512.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-2512.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-2512.py)|\n| [Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Edit.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh) | 
[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py) |\n| [Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py) 
|\n|[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)|\n|[FireRedTeam/FireRed-Image-Edit-1.0](https://www.modelscope.cn/models/FireRedTeam/FireRed-Image-Edit-1.0)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/FireRed-Image-Edit-1.0.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/FireRed-Image-Edit-1.0.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/FireRed-Image-Edit-1.0.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/FireRed-Image-Edit-1.0.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/FireRed-Image-Edit-1.0.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/FireRed-Image-Edit-1.0.py)|\n|[FireRedTeam/FireRed-Image-Edit-1.1](https://www.modelscope.cn/models/FireRedTeam/FireRed-Image-Edit-1.1)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_in
ference/FireRed-Image-Edit-1.1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/FireRed-Image-Edit-1.1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/FireRed-Image-Edit-1.1.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/FireRed-Image-Edit-1.1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/FireRed-Image-Edit-1.1.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/FireRed-Image-Edit-1.1.py)|\n|[lightx2v/Qwen-Image-Edit-2511-Lightning](https://modelscope.cn/models/lightx2v/Qwen-Image-Edit-2511-Lightning)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-Lightning.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511-Lightning.py)|-|-|-|-|\n|[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)|\n|[DiffSynth-Stu
dio/Qwen-Image-Layered-Control](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Layered-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Layered-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control.py)|\n|[DiffSynth-Studio/Qwen-Image-Layered-Control-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control-V2)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Layered-Control-V2.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control-V2.py)|-|-|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control-V2.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control-V2.py)|\n| [DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-EliGen.py) | 
[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py) | - | - | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py) |\n| [DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py) | - | - | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py) |\n| [DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py) | - | - | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py) |\n| [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py) | 
[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-Full.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py) |\n| [DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Distill-LoRA.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-LoRA.py) | - | - | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-LoRA.py) |\n| [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Canny.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh) | 
[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Canny.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Canny.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Canny.py) |\n| [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Depth.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Depth.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Depth.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Depth.py) |\n| [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | 
[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py) |\n| [DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py) | - | - | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py) |\n| [DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py) | - | - | - | - 
|\n|[DiffSynth-Studio/Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-i2L.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-i2L.py)|-|-|-|-|\n\nSpecial Training Scripts:\n\n* Differential LoRA Training: [doc](../Training/Differential_LoRA.md), [code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/qwen_image/model_training/special/differential_training/)\n* FP8 Precision Training: [doc](../Training/FP8_Precision.md), [code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/qwen_image/model_training/special/fp8_training/)\n* Two-stage Split Training: [doc](../Training/Split_Training.md), [code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/qwen_image/model_training/special/split_training/)\n* End-to-end Direct Distillation: [doc](../Training/Direct_Distill.md), [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh)\n\nDeepSpeed ZeRO Stage 3 Training: The Qwen-Image series models support DeepSpeed ZeRO Stage 3 training, which partitions the model across multiple GPUs. 
Taking full parameter training of the Qwen-Image model as an example, the following modifications are required:\n\n* `--config_file examples/qwen_image/model_training/full/accelerate_config_zero3.yaml`\n* `--initialize_model_on_cpu`\n\n## Model Inference\n\nModels are loaded via `QwenImagePipeline.from_pretrained`; see [Loading Models](../Pipeline_Usage/Model_Inference.md#loading-models).\n\nInput parameters for `QwenImagePipeline` inference include:\n\n* `prompt`: Prompt describing the content appearing in the image.\n* `negative_prompt`: Negative prompt describing content that should not appear in the image, default value is `\"\"`.\n* `cfg_scale`: Classifier-free guidance parameter, default value is 4. When set to 1, it no longer takes effect.\n* `input_image`: Input image for image-to-image generation, used in conjunction with `denoising_strength`.\n* `denoising_strength`: Denoising strength, range is 0~1, default value is 1. When the value approaches 0, the generated image stays close to the input image; when the value approaches 1, the generated image differs more from the input image. When `input_image` is not provided, keep this at 1.\n* `inpaint_mask`: Mask image for image inpainting.\n* `inpaint_blur_size`: Edge blur width for image inpainting.\n* `inpaint_blur_sigma`: Edge blur strength for image inpainting.\n* `height`: Image height, must be a multiple of 16.\n* `width`: Image width, must be a multiple of 16.\n* `seed`: Random seed. Default is `None`, meaning completely random.\n* `rand_device`: Device used to generate the random Gaussian noise, default is `\"cpu\"`. When set to `cuda`, the same seed will produce different results on different GPUs.\n* `num_inference_steps`: Number of inference steps, default value is 30.\n* `exponential_shift_mu`: Fixed parameter used when sampling timesteps. 
Leave blank to derive it from the image width and height.\n* `blockwise_controlnet_inputs`: Blockwise ControlNet model inputs.\n* `eligen_entity_prompts`: Entity prompts for EliGen regional control.\n* `eligen_entity_masks`: Region mask images for EliGen regional control.\n* `eligen_enable_on_negative`: Whether to also apply EliGen regional control on the negative-prompt branch of CFG.\n* `edit_image`: Image(s) to be edited by the Edit models; multiple images are supported.\n* `edit_image_auto_resize`: Whether to automatically resize the edit images.\n* `edit_rope_interpolation`: Whether to enable RoPE interpolation for low-resolution edit images.\n* `context_image`: Input image for In-Context Control.\n* `tiled`: Whether to enable tiled VAE inference, default is `False`. Setting it to `True` can significantly reduce VRAM usage during VAE encoding/decoding, at the cost of slight output deviations and slightly longer inference time.\n* `tile_size`: Tile size during VAE encoding/decoding, default is 128, only effective when `tiled=True`.\n* `tile_stride`: Tile stride during VAE encoding/decoding, default is 64, only effective when `tiled=True`; must be less than or equal to `tile_size`.\n* `progress_bar_cmd`: Progress bar, default is `tqdm.tqdm`. It can be disabled by setting it to `lambda x:x`.\n\nIf VRAM is insufficient, please enable [VRAM Management](../Pipeline_Usage/VRAM_management.md). 
We provide recommended low VRAM configurations for each model in the example code; see the table in the \"Model Overview\" section above.\n\n## Model Training\n\nQwen-Image series models are uniformly trained through [`examples/qwen_image/model_training/train.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/train.py), and the script parameters include:\n\n* General Training Parameters\n    * Dataset Basic Configuration\n        * `--dataset_base_path`: Root directory of the dataset.\n        * `--dataset_metadata_path`: Metadata file path of the dataset.\n        * `--dataset_repeat`: Number of times the dataset is repeated in each epoch.\n        * `--dataset_num_workers`: Number of processes for each DataLoader.\n        * `--data_file_keys`: Field names to be loaded from metadata, usually image or video file paths, separated by `,`.\n    * Model Loading Configuration\n        * `--model_paths`: Paths of models to be loaded. JSON format.\n        * `--model_id_with_origin_paths`: Model IDs with original paths, e.g., `\"Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors\"`. Separated by commas.\n        * `--extra_inputs`: Extra input parameters required by the model pipeline, separated by `,`; e.g., the extra parameter `edit_image` when training the image editing model Qwen-Image-Edit.\n        * `--fp8_models`: Models loaded in FP8 format, using the same format as `--model_paths` or `--model_id_with_origin_paths`. Currently only models whose parameters are not updated by gradients are supported (no gradient backpropagation, or gradients only update their LoRA).\n    * Training Basic Configuration\n        * `--learning_rate`: Learning rate.\n        * `--num_epochs`: Number of epochs.\n        * `--trainable_models`: Trainable models, e.g., `dit`, `vae`, `text_encoder`.\n        * `--find_unused_parameters`: Whether there are unused parameters in DDP training. 
Some models contain redundant parameters that do not participate in gradient calculation, and this setting needs to be enabled to avoid errors in multi-GPU training.\n        * `--weight_decay`: Weight decay coefficient, see [torch.optim.AdamW](https://docs.pytorch.org/docs/stable/generated/torch.optim.AdamW.html).\n        * `--task`: Training task, default is `sft`. Some models support additional training modes; please refer to the documentation of each specific model.\n    * Output Configuration\n        * `--output_path`: Model saving path.\n        * `--remove_prefix_in_ckpt`: Prefix to remove from parameter names in the state dict of the saved model file.\n        * `--save_steps`: Interval of training steps to save the model. If this parameter is left blank, the model is saved once per epoch.\n    * LoRA Configuration\n        * `--lora_base_model`: Which model to add LoRA to.\n        * `--lora_target_modules`: Which layers to add LoRA to.\n        * `--lora_rank`: Rank of LoRA.\n        * `--lora_checkpoint`: Path of the LoRA checkpoint. If this path is provided, LoRA will be loaded from this checkpoint.\n        * `--preset_lora_path`: Preset LoRA checkpoint path. If this path is provided, this LoRA will be merged into the base model when loaded. This parameter is used for LoRA differential training.\n        * `--preset_lora_model`: Model that the preset LoRA is merged into, e.g., `dit`.\n    * Gradient Configuration\n        * `--use_gradient_checkpointing`: Whether to enable gradient checkpointing.\n        * `--use_gradient_checkpointing_offload`: Whether to offload the activations saved by gradient checkpointing to CPU memory.\n        * `--gradient_accumulation_steps`: Number of gradient accumulation steps.\n    * Image Width/Height Configuration (Applicable to Image Generation and Video Generation Models)\n        * `--height`: Height of image or video. Leave `height` and `width` blank to enable dynamic resolution.\n        * `--width`: Width of image or video. 
Leave `height` and `width` blank to enable dynamic resolution.\n        * `--max_pixels`: Maximum pixel area of image or video frames. When dynamic resolution is enabled, images with resolution larger than this value will be downscaled, and images with resolution smaller than this value will remain unchanged.\n* Qwen-Image Specific Parameters\n    * `--tokenizer_path`: Path of the tokenizer, applicable to text-to-image models, leave blank to automatically download from remote.\n    * `--processor_path`: Path of the processor, applicable to image editing models, leave blank to automatically download from remote.\n\nWe have built a sample image dataset for your testing. You can download this dataset with the following command:\n\n```shell\nmodelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --local_dir ./data/diffsynth_example_dataset\n```\n\nWe have written recommended training scripts for each model, please refer to the table in the \"Model Overview\" section above. For how to write model training scripts, please refer to [Model Training](../Pipeline_Usage/Model_Training.md); for more advanced training algorithms, please refer to [Training Framework Detailed Explanation](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/en/Training/).\n"
  },
  {
    "path": "docs/en/Model_Details/Wan.md",
    "content": "# Wan\n\nhttps://github.com/user-attachments/assets/1d66ae74-3b02-40a9-acc3-ea95fc039314\n\nWan is a video generation model series developed by the Tongyi Wanxiang Team of Alibaba Tongyi Lab.\n\n## Installation\n\nBefore using this project for model inference and training, please install DiffSynth-Studio first.\n\n```shell\ngit clone https://github.com/modelscope/DiffSynth-Studio.git\ncd DiffSynth-Studio\npip install -e .\n```\n\nFor more information about installation, please refer to [Install Dependencies](../Pipeline_Usage/Setup.md).\n\n## Quick Start\n\nRun the following code to quickly load the [Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B) model and perform inference. VRAM management is enabled, and the framework will automatically control model parameter loading based on remaining VRAM. Minimum 8GB VRAM is required to run.\n\n```python\nimport torch\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\", **vram_config),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"Wan2.1_VAE.pth\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n    
vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 2,\n)\n\nvideo = pipe(\n    prompt=\"纪实摄影风格画面，一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄，两只耳朵立起，神情专注而欢快。阳光洒在它身上，使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地，偶尔点缀着几朵野花，远处隐约可见蓝天和几片白云。透视感鲜明，捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    seed=0, tiled=True,\n)\nsave_video(video, \"video.mp4\", fps=15, quality=5)\n```\n\n## Model Overview\n\n<details>\n\n<summary>Model Lineage</summary>\n\n```mermaid\ngraph LR;\n    Wan-Series-->Wan2.1-Series;\n    Wan-Series-->Wan2.2-Series;\n    Wan2.1-Series-->Wan-AI/Wan2.1-T2V-1.3B;\n    Wan2.1-Series-->Wan-AI/Wan2.1-T2V-14B;\n    Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-I2V-14B-480P;\n    Wan-AI/Wan2.1-I2V-14B-480P-->Wan-AI/Wan2.1-I2V-14B-720P;\n    Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-FLF2V-14B-720P;\n    Wan-AI/Wan2.1-T2V-1.3B-->iic/VACE-Wan2.1-1.3B-Preview;\n    iic/VACE-Wan2.1-1.3B-Preview-->Wan-AI/Wan2.1-VACE-1.3B;\n    Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-VACE-14B;\n    Wan-AI/Wan2.1-T2V-1.3B-->Wan2.1-Fun-1.3B-Series;\n    Wan2.1-Fun-1.3B-Series-->PAI/Wan2.1-Fun-1.3B-InP;\n    Wan2.1-Fun-1.3B-Series-->PAI/Wan2.1-Fun-1.3B-Control;\n    Wan-AI/Wan2.1-T2V-14B-->Wan2.1-Fun-14B-Series;\n    Wan2.1-Fun-14B-Series-->PAI/Wan2.1-Fun-14B-InP;\n    Wan2.1-Fun-14B-Series-->PAI/Wan2.1-Fun-14B-Control;\n    Wan-AI/Wan2.1-T2V-1.3B-->Wan2.1-Fun-V1.1-1.3B-Series;\n    Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-Control;\n    Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-InP;\n    Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera;\n    Wan-AI/Wan2.1-T2V-14B-->Wan2.1-Fun-V1.1-14B-Series;\n    Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-Control;\n    Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-InP;\n    Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-Control-Camera;\n    
Wan-AI/Wan2.1-T2V-1.3B-->DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1;\n    Wan-AI/Wan2.1-T2V-14B-->krea/krea-realtime-video;\n    Wan-AI/Wan2.1-I2V-14B-720P-->ByteDance/Video-As-Prompt-Wan2.1-14B;\n    Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.2-Animate-14B;\n    Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.2-S2V-14B;\n    Wan2.2-Series-->Wan-AI/Wan2.2-T2V-A14B;\n    Wan2.2-Series-->Wan-AI/Wan2.2-I2V-A14B;\n    Wan2.2-Series-->Wan-AI/Wan2.2-TI2V-5B;\n    Wan-AI/Wan2.2-T2V-A14B-->Wan2.2-Fun-Series;\n    Wan2.2-Fun-Series-->PAI/Wan2.2-VACE-Fun-A14B;\n    Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-InP;\n    Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-Control;\n    Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-Control-Camera;\n```\n\n</details>\n\n| Model ID | Extra Inputs | Inference | Low VRAM Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training |\n|-|-|-|-|-|-|-|-|\n|[Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-T2V-1.3B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-1.3B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-1.3B.py)|\n|[Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-T2V-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/ex
amples/wanvideo/model_inference_low_vram/Wan2.1-T2V-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py)|\n|[Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-I2V-14B-480P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py)|\n|[Wan-AI/Wan2.1-I2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P)|`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-I2V-14B-720P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-I2V-14B-720P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvi
deo/model_training/validate_full/Wan2.1-I2V-14B-720P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-720P.py)|\n|[Wan-AI/Wan2.1-FLF2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P)|`input_image`, `end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-FLF2V-14B-720P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-FLF2V-14B-720P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-FLF2V-14B-720P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-FLF2V-14B-720P.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py)|\n|[iic/VACE-Wan2.1-1.3B-Preview](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview)|`vace_control_video`, 
`vace_reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B-Preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-VACE-1.3B-Preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B-Preview.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B-Preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B-Preview.py)|\n|[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|`vace_control_video`, `vace_reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-VACE-1.3B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B.py)|\n|[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, 
`vace_reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-VACE-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-VACE-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py)|\n|[PAI/Wan2.1-Fun-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP)|`input_image`, `end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-1.3B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py)|\n|[PAI/Wan2.1-Fun-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control)|`control_video`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/W
an2.1-Fun-1.3B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py)|\n|[PAI/Wan2.1-Fun-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP)|`input_image`, `end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-InP.py)|\n|[PAI/Wan2.1-Fun-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control)|`control_video`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/e
xamples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-Control.py)|\n|[PAI/Wan2.1-Fun-V1.1-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control)|`control_video`, `reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py)|\n|[PAI/Wan2.1-Fun-V1.1-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control)|`control_video`, 
`reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py)|\n|[PAI/Wan2.1-Fun-V1.1-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-InP)|`input_image`, `end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-InP.py)|\n|[PAI/Wan2.1-Fun-V1.1-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-InP)|`input_image`, 
`end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-InP.py)|\n|[PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera)|`control_camera_video`, `input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|\n|[PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera)|`control_camera_video`, 
`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|\n|[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-1.3b-speedcontrol-v1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)|\n|[krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/
examples/wanvideo/model_inference/krea-realtime-video.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/krea-realtime-video.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/krea-realtime-video.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/krea-realtime-video.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/krea-realtime-video.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/krea-realtime-video.py)|\n|[meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video)|`longcat_video`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/LongCat-Video.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/LongCat-Video.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/LongCat-Video.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/LongCat-Video.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/LongCat-Video.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/LongCat-Video.py)|\n|[ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B)|`vap_video`, 
`vap_prompt`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Video-As-Prompt-Wan2.1-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Video-As-Prompt-Wan2.1-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Video-As-Prompt-Wan2.1-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Video-As-Prompt-Wan2.1-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Video-As-Prompt-Wan2.1-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Video-As-Prompt-Wan2.1-14B.py)|\n|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-T2V-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py)|\n|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-I2V-A14B.p
y)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py)|\n|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-TI2V-5B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py)|\n|[Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B)|`input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, 
`animate_mask_video`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Animate-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-Animate-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py)|\n|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-S2V-14B_multi_clips.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-S2V-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-S2V-14B.py)|\n|[PAI/Wan2.2-VACE-Fun-A14B](https://www.modelscope.cn/models/PAI/Wan2.2-VACE-Fun-A14B)|`vace_control_video`, 
`vace_reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-VACE-Fun-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-VACE-Fun-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-VACE-Fun-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-VACE-Fun-A14B.py)|\n|[PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP)|`input_image`, `end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-Fun-A14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py)|\n|[PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control)|`control_video`, 
`reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-Fun-A14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py)|\n|[PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera)|`control_camera_video`, `input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-Fun-A14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py)|\n|[openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p)|`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_inference/MOVA-360p-I2AV.py)
|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_inference_low_vram/MOVA-360p-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/full/MOVA-360P-I2AV.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_full/MOVA-360p-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/lora/MOVA-360P-I2AV.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_lora/MOVA-360p-I2AV.py)|\n|[openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p)|`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_inference/MOVA-720p-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_inference_low_vram/MOVA-720p-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/full/MOVA-720P-I2AV.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_full/MOVA-720p-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/lora/MOVA-720P-I2AV.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_lora/MOVA-720p-I2AV.py)|\n|[Wan-AI/WanToDance-14B (global model)](https://modelscope.cn/models/Wan-AI/WanToDance-14B)|`wantodance_music_path`, `wantodance_reference_image`, `wantodance_fps`, `wantodance_keyframes`, 
`wantodance_keyframes_mask`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/WanToDance-14B-global.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/WanToDance-14B-global.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/WanToDance-14B-global.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/WanToDance-14B-global.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/WanToDance-14B-global.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/WanToDance-14B-global.py)|\n|[Wan-AI/WanToDance-14B (local model)](https://modelscope.cn/models/Wan-AI/WanToDance-14B)|`wantodance_music_path`, `wantodance_reference_image`, `wantodance_fps`, `wantodance_keyframes`, `wantodance_keyframes_mask`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/WanToDance-14B-local.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/WanToDance-14B-local.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/WanToDance-14B-local.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/WanToDance-14B-local.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/WanToDance-14B-local.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/WanToDance-14B-local.py)|\n\n* FP8 Precision Training: [doc](../Training/FP8_Precision.md), 
[code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo/model_training/special/fp8_training/)\n* Two-stage Split Training: [doc](../Training/Split_Training.md), [code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo/model_training/special/split_training/)\n* End-to-end Direct Distillation: [doc](../Training/Direct_Distill.md), [code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo/model_training/special/direct_distill/)\n\nDeepSpeed ZeRO Stage 3 Training: The Wan series models support DeepSpeed ZeRO Stage 3 training, which partitions the model across multiple GPUs. Taking full parameter training of the Wan2.1-T2V-14B model as an example, the following modifications are required:\n\n* `--config_file examples/wanvideo/model_training/full/accelerate_config_zero3.yaml`\n* `--initialize_model_on_cpu`\n\n## Model Inference\n\nModels are loaded via `WanVideoPipeline.from_pretrained`, see [Loading Models](../Pipeline_Usage/Model_Inference.md#loading-models).\n\nInput parameters for `WanVideoPipeline` inference include:\n\n* `prompt`: Prompt describing the content appearing in the video.\n* `negative_prompt`: Negative prompt describing content that should not appear in the video, default value is `\"\"`.\n* `cfg_scale`: Classifier-free guidance parameter, default value is 5. When set to 1, it no longer takes effect.\n* `input_image`: Input image for image-to-video generation, used in conjunction with `denoising_strength`.\n* `end_image`: End image for first-and-last frame video generation.\n* `input_video`: Input video for video-to-video generation, used in conjunction with `denoising_strength`.\n* `denoising_strength`: Denoising strength, range is 0~1, default value is 1. 
When the value approaches 0, the generated video is similar to the input video; when the value approaches 1, the generated video differs more from the input video.\n* `control_video`: Control video for controlling the video generation process.\n* `reference_image`: Reference image for maintaining consistency of certain features in the generated video.\n* `camera_control_direction`: Camera control direction, valid values are `\"Left\"`, `\"Right\"`, `\"Up\"`, `\"Down\"`, `\"LeftUp\"`, `\"LeftDown\"`, `\"RightUp\"`, `\"RightDown\"`.\n* `camera_control_speed`: Camera control speed, default value is 1/54.\n* `vace_video`: VACE control video.\n* `vace_video_mask`: VACE control video mask.\n* `vace_reference_image`: VACE reference image.\n* `vace_scale`: VACE control strength, default value is 1.0.\n* `animate_pose_video`: `animate` model pose video.\n* `animate_face_video`: `animate` model face video.\n* `animate_inpaint_video`: `animate` model local editing video.\n* `animate_mask_video`: `animate` model mask video.\n* `vap_video`: `video-as-prompt` input video.\n* `vap_prompt`: `video-as-prompt` text description.\n* `negative_vap_prompt`: `video-as-prompt` negative text description.\n* `input_audio`: Input audio for speech-to-video generation.\n* `audio_embeds`: Audio embedding vectors.\n* `audio_sample_rate`: Audio sampling rate, default value is 16000.\n* `s2v_pose_video`: S2V model pose video.\n* `motion_video`: S2V model motion video.\n* `height`: Video height, must be a multiple of 16.\n* `width`: Video width, must be a multiple of 16.\n* `num_frames`: Number of video frames, default value is 81; must be of the form 4n + 1 (e.g., 81 = 4 × 20 + 1).\n* `seed`: Random seed. Default is `None`, meaning completely random.\n* `rand_device`: Computing device for generating random Gaussian noise matrix, default is `\"cpu\"`. 
When set to `cuda`, different GPUs will produce different generation results.\n* `num_inference_steps`: Number of inference steps, default value is 50.\n* `motion_bucket_id`: Motion control parameter, the larger the value, the greater the motion amplitude.\n* `longcat_video`: LongCat input video.\n* `tiled`: Whether to enable VAE tiling inference, default is `True`. Setting to `True` can significantly reduce VRAM usage during VAE encoding/decoding stages, at the cost of slight numerical errors and slightly longer inference time.\n* `tile_size`: Tile size during VAE encoding/decoding stages, default is `(30, 52)`, only effective when `tiled=True`.\n* `tile_stride`: Tile stride during VAE encoding/decoding stages, default is `(15, 26)`, only effective when `tiled=True`, must be less than or equal to `tile_size`.\n* `switch_DiT_boundary`: Time boundary for switching DiT models, default value is 0.875.\n* `sigma_shift`: Timestep offset parameter, default value is 5.0.\n* `sliding_window_size`: Sliding window size.\n* `sliding_window_stride`: Sliding window stride.\n* `tea_cache_l1_thresh`: L1 threshold for TeaCache.\n* `tea_cache_model_id`: Model ID used by TeaCache.\n* `progress_bar_cmd`: Progress bar, default is `tqdm.tqdm`. Can be disabled by setting to `lambda x:x`.\n\nIf VRAM is insufficient, please enable [VRAM Management](../Pipeline_Usage/VRAM_management.md). 
We provide recommended low VRAM configurations for each model in the example code, see the table in the \"Model Overview\" section above.\n\n### Multi-GPU Parallel Acceleration\n\nTo enable multi-GPU parallel acceleration, please install `flash_attn` and `xfuser`:\n\n```shell\npip install flash-attn --no-build-isolation\npip install xfuser\n```\n\nPlease modify your code as follows ([example code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo/acceleration/unified_sequence_parallel.py)):\n\n```diff\nimport torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\n+ import torch.distributed as dist\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n+   use_usp=True,\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-14B\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-14B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-14B\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n)\nvideo = pipe(\n    prompt=\"An astronaut in a spacesuit rides a mechanical horse across the Martian surface, facing the camera. The red, desolate terrain stretches into the distance, dotted with massive craters and unusual rock formations. The mechanical horse moves with steady strides, kicking up faint dust, embodying a perfect fusion of futuristic technology and primal exploration. The astronaut holds a control device, with a determined gaze, as if pioneering new frontiers for humanity. 
Against a backdrop of the deep cosmos and the blue Earth, the scene is both sci-fi and hopeful, evoking imagination about future interstellar life.\",\n    negative_prompt=\"oversaturated colors, overexposed, static, blurry details, subtitles, style, artwork, painting, still image, overall gray tone, worst quality, low quality, JPEG compression artifacts, ugly, malformed, extra fingers, poorly drawn hands, poorly drawn face, deformed, disfigured, malformed limbs, fused fingers, frozen frame, cluttered background, three legs, crowd in background, walking backwards\",\n    seed=0, tiled=True,\n)\n+ if dist.get_rank() == 0:\n+   save_video(video, \"video1.mp4\", fps=15, quality=5)\n```\n\nWhen running multi-GPU parallel inference, please use `torchrun`, where `--nproc_per_node` specifies the number of GPUs:\n\n```shell\ntorchrun --nproc_per_node=8 examples/wanvideo/acceleration/unified_sequence_parallel.py\n```\n\n## Model Training\n\nWan series models are uniformly trained through [`examples/wanvideo/model_training/train.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/train.py), and the script parameters include:\n\n* General Training Parameters\n    * Dataset Basic Configuration\n        * `--dataset_base_path`: Root directory of the dataset.\n        * `--dataset_metadata_path`: Metadata file path of the dataset.\n        * `--dataset_repeat`: Number of times the dataset is repeated in each epoch.\n        * `--dataset_num_workers`: Number of processes for each DataLoader.\n        * `--data_file_keys`: Field names to be loaded from metadata, usually image or video file paths, separated by `,`.\n    * Model Loading Configuration\n        * `--model_paths`: Paths of models to be loaded. JSON format.\n        * `--model_id_with_origin_paths`: Model IDs with original paths, e.g., `\"Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors\"`. 
Separated by commas.\n        * `--extra_inputs`: Extra input parameters required by the model Pipeline, e.g., extra parameters when training image editing models, separated by `,`.\n        * `--fp8_models`: Models loaded in FP8 format, consistent with `--model_paths` or `--model_id_with_origin_paths` format. Currently only supports models whose parameters are not updated by gradients (no gradient backpropagation, or gradients only update their LoRA).\n    * Training Basic Configuration\n        * `--learning_rate`: Learning rate.\n        * `--num_epochs`: Number of epochs.\n        * `--trainable_models`: Trainable models, e.g., `dit`, `vae`, `text_encoder`.\n        * `--find_unused_parameters`: Whether there are unused parameters in DDP training. Some models contain redundant parameters that do not participate in gradient calculation, and this setting needs to be enabled to avoid errors in multi-GPU training.\n        * `--weight_decay`: Weight decay size, see [torch.optim.AdamW](https://docs.pytorch.org/docs/stable/generated/torch.optim.AdamW.html).\n        * `--task`: Training task, default is `sft`. Some models support more training modes, please refer to the documentation of each specific model.\n    * Output Configuration\n        * `--output_path`: Model saving path.\n        * `--remove_prefix_in_ckpt`: Remove prefix in the state dict of the model file.\n        * `--save_steps`: Interval of training steps to save the model. If this parameter is left blank, the model is saved once per epoch.\n    * LoRA Configuration\n        * `--lora_base_model`: Which model to add LoRA to.\n        * `--lora_target_modules`: Which layers to add LoRA to.\n        * `--lora_rank`: Rank of LoRA.\n        * `--lora_checkpoint`: Path of the LoRA checkpoint. If this path is provided, LoRA will be loaded from this checkpoint.\n        * `--preset_lora_path`: Preset LoRA checkpoint path. 
If this path is provided, this LoRA will be merged into the base model when it is loaded. This parameter is used for LoRA differential training.\n        * `--preset_lora_model`: Model that the preset LoRA is merged into, e.g., `dit`.\n    * Gradient Configuration\n        * `--use_gradient_checkpointing`: Whether to enable gradient checkpointing.\n        * `--use_gradient_checkpointing_offload`: Whether to offload gradient checkpointing to CPU memory.\n        * `--gradient_accumulation_steps`: Number of gradient accumulation steps.\n    * Video Width/Height Configuration\n        * `--height`: Height of the video. Leave `height` and `width` blank to enable dynamic resolution.\n        * `--width`: Width of the video. Leave `height` and `width` blank to enable dynamic resolution.\n        * `--max_pixels`: Maximum pixel area of video frames. When dynamic resolution is enabled, video frames with resolution larger than this value will be downscaled, and video frames with resolution smaller than this value will remain unchanged.\n        * `--num_frames`: Number of frames in the video.\n* Wan Series Specific Parameters\n    * `--tokenizer_path`: Path of the tokenizer, applicable to text-to-video models, leave blank to automatically download from remote.\n    * `--audio_processor_path`: Path of the audio processor, applicable to speech-to-video models, leave blank to automatically download from remote.\n\nWe have built a sample video dataset for your testing. You can download this dataset with the following command:\n\n```shell\nmodelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --local_dir ./data/diffsynth_example_dataset\n```\n\nWe have written recommended training scripts for each model; please refer to the table in the \"Model Overview\" section above. 
For how to write model training scripts, please refer to [Model Training](../Pipeline_Usage/Model_Training.md); for more advanced training algorithms, please refer to [Training Framework Detailed Explanation](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/en/Training/).\n"
  },
  {
    "path": "docs/en/Model_Details/Z-Image.md",
    "content": "# Z-Image\n\nZ-Image is an image generation model trained and open-sourced by the Multimodal Interaction Team of Alibaba Tongyi Lab.\n\n## Installation\n\nBefore using this project for model inference and training, please install DiffSynth-Studio first.\n\n```shell\ngit clone https://github.com/modelscope/DiffSynth-Studio.git\ncd DiffSynth-Studio\npip install -e .\n```\n\nFor more information about installation, please refer to [Install Dependencies](../Pipeline_Usage/Setup.md).\n\n## Quick Start\n\nRun the following code to quickly load the [Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo) model and perform inference. FP8 precision quantization causes noticeable image quality degradation, so it is not recommended to enable any quantization on the Z-Image Turbo model. Only CPU Offload is recommended, minimum 8GB VRAM is required to run.\n\n```python\nfrom diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig\nimport torch\n\nvram_config = {\n    \"offload_dtype\": torch.bfloat16,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = ZImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"transformer/*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"text_encoder/*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"tokenizer/\"),\n    
vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\nprompt = \"Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights.\"\nimage = pipe(prompt=prompt, seed=42, rand_device=\"cuda\")\nimage.save(\"image.jpg\")\n```\n\n## Model Overview\n\n|Model ID|Inference|Low VRAM Inference|Full Training|Validation After Full Training|LoRA Training|Validation After LoRA Training|\n|-|-|-|-|-|-|-|\n|[Tongyi-MAI/Z-Image](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_inference/Z-Image.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_inference_low_vram/Z-Image.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/full/Z-Image.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/validate_full/Z-Image.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/lora/Z-Image.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/validate_lora/Z-Image.py)|\n|[DiffSynth-Studio/Z-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-i2L)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_inference/Z-Image-i2L.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_inference_low_vram/Z-Image-i2L.py)|-|-|-|-|\n|[Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo)|[code](https://github.com/modelscope/DiffSynth-Stud
io/blob/main/examples/z_image/model_inference/Z-Image-Turbo.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_inference_low_vram/Z-Image-Turbo.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/full/Z-Image-Turbo.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/validate_full/Z-Image-Turbo.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/lora/Z-Image-Turbo.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/validate_lora/Z-Image-Turbo.py)|\n|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|\n|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](https://github.com/model
scope/DiffSynth-Studio/blob/main/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|\n|[PAI/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|\n\nSpecial Training Scripts:\n\n* Differential LoRA Training: [doc](../Training/Differential_LoRA.md), 
[code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/z_image/model_training/special/differential_training/)\n* Trajectory Imitation Distillation Training (Experimental Feature): [code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/z_image/model_training/special/trajectory_imitation/)\n\n## Model Inference\n\nModels are loaded via `ZImagePipeline.from_pretrained`, see [Loading Models](../Pipeline_Usage/Model_Inference.md#loading-models).\n\nInput parameters for `ZImagePipeline` inference include:\n\n* `prompt`: Prompt describing the content appearing in the image.\n* `negative_prompt`: Negative prompt describing content that should not appear in the image, default value is `\"\"`.\n* `cfg_scale`: Classifier-free guidance parameter, default value is 1.\n* `input_image`: Input image for image-to-image generation, used in conjunction with `denoising_strength`.\n* `denoising_strength`: Denoising strength, range is 0~1, default value is 1. When the value approaches 0, the generated image is similar to the input image; when the value approaches 1, the generated image differs more from the input image. When `input_image` parameter is not provided, do not set this to a non-1 value.\n* `height`: Image height, must be a multiple of 16.\n* `width`: Image width, must be a multiple of 16.\n* `seed`: Random seed. Default is `None`, meaning completely random.\n* `rand_device`: Computing device for generating random Gaussian noise matrix, default is `\"cpu\"`. When set to `cuda`, different GPUs will produce different generation results.\n* `num_inference_steps`: Number of inference steps, default value is 8.\n* `controlnet_inputs`: Inputs for ControlNet models.\n* `edit_image`: Edit images for image editing models, supporting multiple images.\n* `positive_only_lora`: LoRA weights used only in positive prompts.\n\nIf VRAM is insufficient, please enable [VRAM Management](../Pipeline_Usage/VRAM_management.md). 
We provide recommended low VRAM configurations for each model in the example code, see the table in the \"Model Overview\" section above.\n\n## Model Training\n\nZ-Image series models are uniformly trained through [`examples/z_image/model_training/train.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/train.py), and the script parameters include:\n\n* General Training Parameters\n    * Dataset Basic Configuration\n        * `--dataset_base_path`: Root directory of the dataset.\n        * `--dataset_metadata_path`: Metadata file path of the dataset.\n        * `--dataset_repeat`: Number of times the dataset is repeated in each epoch.\n        * `--dataset_num_workers`: Number of processes for each DataLoader.\n        * `--data_file_keys`: Field names to be loaded from metadata, usually image or video file paths, separated by `,`.\n    * Model Loading Configuration\n        * `--model_paths`: Paths of models to be loaded. JSON format.\n        * `--model_id_with_origin_paths`: Model IDs with original paths, e.g., `\"Tongyi-MAI/Z-Image-Turbo:transformer/*.safetensors\"`. Separated by commas.\n        * `--extra_inputs`: Extra input parameters required by the model Pipeline, e.g., extra parameters when training image editing models, separated by `,`.\n        * `--fp8_models`: Models loaded in FP8 format, consistent with `--model_paths` or `--model_id_with_origin_paths` format. Currently only supports models whose parameters are not updated by gradients (no gradient backpropagation, or gradients only update their LoRA).\n    * Training Basic Configuration\n        * `--learning_rate`: Learning rate.\n        * `--num_epochs`: Number of epochs.\n        * `--trainable_models`: Trainable models, e.g., `dit`, `vae`, `text_encoder`.\n        * `--find_unused_parameters`: Whether there are unused parameters in DDP training. 
Some models contain redundant parameters that do not participate in gradient calculation, and this setting needs to be enabled to avoid errors in multi-GPU training.\n        * `--weight_decay`: Weight decay size, see [torch.optim.AdamW](https://docs.pytorch.org/docs/stable/generated/torch.optim.AdamW.html).\n        * `--task`: Training task, default is `sft`. Some models support more training modes, please refer to the documentation of each specific model.\n    * Output Configuration\n        * `--output_path`: Model saving path.\n        * `--remove_prefix_in_ckpt`: Remove prefix in the state dict of the model file.\n        * `--save_steps`: Interval of training steps to save the model. If this parameter is left blank, the model is saved once per epoch.\n    * LoRA Configuration\n        * `--lora_base_model`: Which model to add LoRA to.\n        * `--lora_target_modules`: Which layers to add LoRA to.\n        * `--lora_rank`: Rank of LoRA.\n        * `--lora_checkpoint`: Path of the LoRA checkpoint. If this path is provided, LoRA will be loaded from this checkpoint.\n        * `--preset_lora_path`: Preset LoRA checkpoint path. If this path is provided, this LoRA will be loaded in the form of being merged into the base model. This parameter is used for LoRA differential training.\n        * `--preset_lora_model`: Model that the preset LoRA is merged into, e.g., `dit`.\n    * Gradient Configuration\n        * `--use_gradient_checkpointing`: Whether to enable gradient checkpointing.\n        * `--use_gradient_checkpointing_offload`: Whether to offload gradient checkpointing to memory.\n        * `--gradient_accumulation_steps`: Number of gradient accumulation steps.\n    * Image Width/Height Configuration (Applicable to Image Generation and Video Generation Models)\n        * `--height`: Height of image or video. Leave `height` and `width` blank to enable dynamic resolution.\n        * `--width`: Width of image or video. 
Leave `height` and `width` blank to enable dynamic resolution.\n        * `--max_pixels`: Maximum pixel area of image or video frames. When dynamic resolution is enabled, images with resolution larger than this value will be downscaled, and images with resolution smaller than this value will remain unchanged.\n* Z-Image Specific Parameters\n    * `--tokenizer_path`: Path of the tokenizer, applicable to text-to-image models, leave blank to automatically download from remote.\n\nWe have built a sample image dataset for your testing. You can download this dataset with the following command:\n\n```shell\nmodelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --local_dir ./data/diffsynth_example_dataset\n```\n\nWe have written recommended training scripts for each model, please refer to the table in the \"Model Overview\" section above. For how to write model training scripts, please refer to [Model Training](../Pipeline_Usage/Model_Training.md); for more advanced training algorithms, please refer to [Training Framework Detailed Explanation](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/en/Training/).\n\nTraining Tips:\n\n* [Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo) is a distilled acceleration model. Therefore, direct training will quickly cause the model to lose its acceleration capability. The effect of inference with \"acceleration configuration\" (`num_inference_steps=8`, `cfg_scale=1`) becomes worse, while the effect of inference with \"no acceleration configuration\" (`num_inference_steps=30`, `cfg_scale=2`) becomes better. 
The following training and inference schemes can be adopted:\n    * Standard SFT Training ([code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/lora/Z-Image-Turbo.sh)) + No Acceleration Configuration Inference\n    * Differential LoRA Training ([code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/z_image/model_training/special/differential_training/)) + Acceleration Configuration Inference\n        * An additional LoRA needs to be loaded in differential LoRA training, e.g., [ostris/zimage_turbo_training_adapter](https://www.modelscope.cn/models/ostris/zimage_turbo_training_adapter)\n    * Standard SFT Training ([code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/lora/Z-Image-Turbo.sh)) + Trajectory Imitation Distillation Training ([code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/z_image/model_training/special/trajectory_imitation/)) + Acceleration Configuration Inference\n    * Standard SFT Training ([code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/lora/Z-Image-Turbo.sh)) + Load Distillation Acceleration LoRA During Inference ([model](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-Turbo-DistillPatch)) + Acceleration Configuration Inference\n"
  },
  {
    "path": "docs/en/Pipeline_Usage/Environment_Variables.md",
    "content": "# Environment Variables\n\n`DiffSynth-Studio` can control some settings through environment variables.\n\nIn `Python` code, you can set environment variables using `os.environ`. Please note that environment variables must be set before `import diffsynth`.\n\n```python\nimport os\nos.environ[\"DIFFSYNTH_MODEL_BASE_PATH\"] = \"./path_to_my_models\"\nimport diffsynth\n```\n\nOn Linux operating systems, you can also temporarily set environment variables from the command line:\n\n```shell\nDIFFSYNTH_MODEL_BASE_PATH=\"./path_to_my_models\" python xxx.py\n```\n\nBelow are the environment variables supported by `DiffSynth-Studio`.\n\n## `DIFFSYNTH_SKIP_DOWNLOAD`\n\nWhether to skip model downloads. Can be set to `True`, `true`, `False`, `false`. If `skip_download` is not set in `ModelConfig`, this environment variable will determine whether to skip model downloads.\n\n## `DIFFSYNTH_MODEL_BASE_PATH`\n\nModel download root directory. Can be set to any local path. If `local_model_path` is not set in `ModelConfig`, model files will be downloaded to the path pointed to by this environment variable. If neither is set, model files will be downloaded to `./models`.\n\n## `DIFFSYNTH_ATTENTION_IMPLEMENTATION`\n\nAttention implementation to use. Can be set to `flash_attention_3`, `flash_attention_2`, `sage_attention`, `xformers`, or `torch`. See [`diffsynth.core.attention`](../API_Reference/core/attention.md) for details.\n\n## `DIFFSYNTH_DISK_MAP_BUFFER_SIZE`\n\nBuffer size used for disk mapping. Default is `1000000000` (1e9). Larger values use more memory but are faster.\n\n## `DIFFSYNTH_DOWNLOAD_SOURCE`\n\nRemote model download source. Can be set to `modelscope` or `huggingface` to control the source of model downloads. Default value is `modelscope`."
  },
  {
    "path": "docs/en/Pipeline_Usage/GPU_support.md",
    "content": "# GPU/NPU Support\n\n`DiffSynth-Studio` supports various GPUs and NPUs. This document explains how to run model inference and training on these devices.\n\nBefore you begin, please follow the [Installation Guide](../Pipeline_Usage/Setup.md) to install the required GPU/NPU dependencies.\n\n## NVIDIA GPU\n\nAll sample code provided by this project supports NVIDIA GPUs by default, requiring no additional modifications.\n\n## AMD GPU\n\nAMD provides PyTorch packages based on ROCm, so most models can run without code changes. A small number of models may not be compatible due to their reliance on CUDA-specific instructions.\n\n## Ascend NPU\n### Inference\nWhen using Ascend NPU, you need to replace `\"cuda\"` with `\"npu\"` in your code.\n\nFor example, here is the inference code for **Wan2.1-T2V-1.3B**, modified for Ascend NPU:\n\n```diff\nimport torch\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom diffsynth.core.device.npu_compatible_device import get_device_name\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.bfloat16,\n-   \"preparing_device\": \"cuda\",\n+   \"preparing_device\": \"npu\",\n    \"computation_dtype\": torch.bfloat16,\n-   \"computation_device\": \"cuda\",\n+   \"computation_device\": \"npu\",\n}\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n-   device=\"cuda\",\n+   device=\"npu\",\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\", **vram_config),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"Wan2.1_VAE.pth\", **vram_config),\n 
   ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n-   vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 2,\n+   vram_limit=torch.npu.mem_get_info(get_device_name())[1] / (1024 ** 3) - 2,\n)\n\nvideo = pipe(\n    prompt=\"Documentary-style photography: a lively puppy running swiftly across lush green grass. The puppy has brownish-yellow fur, upright ears, and an alert, joyful expression. Sunlight bathes its body, making the fur appear exceptionally soft and shiny. The background is an open field with occasional wildflowers, and faint blue sky with scattered white clouds in the distance. Strong perspective captures the motion of the running puppy and the vitality of the surrounding grass. Mid-shot, side-moving viewpoint.\",\n    negative_prompt=\"Overly vibrant colors, overexposed, static, blurry details, subtitles, artistic style, painting, still image, overall grayish tone, worst quality, low quality, JPEG artifacts, ugly, distorted, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, fused fingers, motionless scene, cluttered background, three legs, many people in background, walking backward\",\n    seed=0, tiled=True,\n)\nsave_video(video, \"video.mp4\", fps=15, quality=5)\n```\n\n#### USP (Unified Sequence Parallel)\nTo use this feature on NPU, install the following additional third-party libraries:\n```shell\npip install git+https://github.com/feifeibear/long-context-attention.git\npip install git+https://github.com/xdit-project/xDiT.git\n```\n\n\n### Training\nSample NPU launch scripts have been added for each type of model. They are stored in `examples/xxx/model_training/special/npu_training`, for example `examples/wanvideo/model_training/special/npu_training/Wan2.2-T2V-A14B-NPU.sh`.\n\nThe NPU training scripts set NPU-specific environment variables that improve performance, and enable the relevant parameters for specific models.\n\n#### Environment variables\n```shell\nexport PYTORCH_NPU_ALLOC_CONF=expandable_segments:True\n```\n`expandable_segments:<value>`: When set to `True`, enables expandable segments in the memory pool (a virtual-memory feature).\n\n```shell\nexport CPU_AFFINITY_CONF=1\n```\nControls CPU core binding:\n\n0 or unset: core binding is disabled.\n\n1: coarse-grained core binding is enabled.\n\n2: fine-grained core binding is enabled.\n\n#### Parameters for specific models\n| Model          | Parameter                 | Note              |\n|----------------|---------------------------|-------------------|\n| Wan 14B series | --initialize_model_on_cpu | The 14B model needs to be initialized on the CPU |\n| Qwen-Image series | --initialize_model_on_cpu | The model needs to be initialized on the CPU |\n| Z-Image series | --enable_npu_patch | Replaces the corresponding operators in the Z-Image model with NPU fused operators to improve performance on NPU |"
  },
  {
    "path": "docs/en/Pipeline_Usage/Model_Inference.md",
    "content": "# Model Inference\n\nThis document uses the Qwen-Image model as an example to introduce how to use `DiffSynth-Studio` for model inference.\n\n## Loading Models\n\nModels are loaded through `from_pretrained`:\n\n```python\nfrom diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nimport torch\n\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n)\n```\n\nWhere `torch_dtype` and `device` are computation precision and computation device (not model precision and device). `model_configs` can be configured in multiple ways for model paths. For how models are loaded internally in this project, please refer to [`diffsynth.core.loader`](../API_Reference/core/loader.md).\n\n<details>\n\n<summary>Download and load models from remote sources</summary>\n\n> `DiffSynth-Studio` downloads and loads models from [ModelScope](https://www.modelscope.cn/) by default. 
You need to fill in `model_id` and `origin_file_pattern`, for example:\n> \n> ```python\n> ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n> ```\n> \n> Model files are downloaded to the `./models` path by default, which can be modified through [environment variable DIFFSYNTH_MODEL_BASE_PATH](../Pipeline_Usage/Environment_Variables.md#diffsynth_model_base_path).\n\n</details>\n\n<details>\n\n<summary>Load models from local file paths</summary>\n\n> Fill in `path`, for example:\n> \n> ```python\n> ModelConfig(path=\"models/xxx.safetensors\")\n> ```\n> \n> For models loaded from multiple files, use a list, for example:\n> \n> ```python\n> ModelConfig(path=[\n>     \"models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors\",\n>     \"models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors\",\n>     \"models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors\",\n>     \"models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors\",\n> ])\n> ```\n\n</details>\n\nBy default, even after models have been downloaded, the program will still query remotely for missing files. 
To completely disable remote requests, set [environment variable DIFFSYNTH_SKIP_DOWNLOAD](../Pipeline_Usage/Environment_Variables.md#diffsynth_skip_download) to `True`.\n\n```python\nimport os\nos.environ[\"DIFFSYNTH_SKIP_DOWNLOAD\"] = \"True\"\nimport diffsynth\n```\n\nTo download models from [HuggingFace](https://huggingface.co/), set [environment variable DIFFSYNTH_DOWNLOAD_SOURCE](../Pipeline_Usage/Environment_Variables.md#diffsynth_download_source) to `huggingface`.\n\n```python\nimport os\nos.environ[\"DIFFSYNTH_DOWNLOAD_SOURCE\"] = \"huggingface\"\nimport diffsynth\n```\n\n## Starting Inference\n\nInput a prompt to start the inference process and generate an image.\n\n```python\nfrom diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nimport torch\n\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n)\nprompt = \"Exquisite portrait, underwater girl, blue dress flowing, hair floating, translucent light, bubbles surrounding, peaceful face, intricate details, dreamy and ethereal.\"\nimage = pipe(prompt, seed=0, num_inference_steps=40)\nimage.save(\"image.jpg\")\n```\n\nEach model `Pipeline` has different input parameters. 
Please refer to the documentation for each model.\n\nIf the model parameters are too large, causing insufficient VRAM, please enable [VRAM management](../Pipeline_Usage/VRAM_management.md).\n\n## Loading LoRA\n\nLoRA is a lightweight model training method that produces a small number of parameters to extend model capabilities. DiffSynth-Studio supports two ways to load LoRA: cold loading and hot loading.\n\n* Cold loading: When the base model does not have [VRAM management](../Pipeline_Usage/VRAM_management.md) enabled, LoRA will be fused into the base model weights. In this case, inference speed remains unchanged, but LoRA cannot be unloaded after loading.\n\n```python\nfrom diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nimport torch\n\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n)\nlora = ModelConfig(model_id=\"DiffSynth-Studio/Qwen-Image-LoRA-ArtAug-v1\", origin_file_pattern=\"model.safetensors\")\npipe.load_lora(pipe.dit, lora, alpha=1)\nprompt = \"Exquisite portrait, underwater girl, blue dress flowing, hair floating, translucent light, bubbles surrounding, peaceful face, intricate details, dreamy and ethereal.\"\nimage = pipe(prompt, seed=0, num_inference_steps=40)\nimage.save(\"image.jpg\")\n```\n\n* Hot loading: When the base model has [VRAM management](../Pipeline_Usage/VRAM_management.md) enabled, LoRA will not be fused into the base model weights. 
In this case, inference speed will be slower, but LoRA can be unloaded through `pipe.clear_lora()` after loading.\n\n```python\nfrom diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nimport torch\n\nvram_config = {\n    \"offload_dtype\": torch.bfloat16,\n    \"offload_device\": \"cuda\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cuda\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n)\nlora = ModelConfig(model_id=\"DiffSynth-Studio/Qwen-Image-LoRA-ArtAug-v1\", origin_file_pattern=\"model.safetensors\")\npipe.load_lora(pipe.dit, lora, alpha=1)\nprompt = \"Exquisite portrait, underwater girl, blue dress flowing, hair floating, translucent light, bubbles surrounding, peaceful face, intricate details, dreamy and ethereal.\"\nimage = pipe(prompt, seed=0, num_inference_steps=40)\nimage.save(\"image.jpg\")\npipe.clear_lora()\n```\n"
  },
  {
    "path": "docs/en/Pipeline_Usage/Model_Training.md",
    "content": "# Model Training\n\nThis document introduces how to use `DiffSynth-Studio` for model training.\n\n## Script Parameters\n\nTraining scripts typically include the following parameters:\n\n* Dataset base configuration\n    * `--dataset_base_path`: Root directory of the dataset.\n    * `--dataset_metadata_path`: Metadata file path of the dataset.\n    * `--dataset_repeat`: Number of times the dataset is repeated in each epoch.\n    * `--dataset_num_workers`: Number of processes for each Dataloader.\n    * `--data_file_keys`: Field names that need to be loaded from metadata, usually image or video file paths, separated by `,`.\n* Model loading configuration\n    * `--model_paths`: Paths of models to be loaded. JSON format.\n    * `--model_id_with_origin_paths`: Model IDs with original paths, for example `\"Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors\"`. Separated by commas.\n    * `--extra_inputs`: Extra input parameters required by the model Pipeline, for example, training image editing model Qwen-Image-Edit requires extra parameter `edit_image`, separated by `,`.\n    * `--fp8_models`: Models loaded in FP8 format, consistent with the format of `--model_paths` or `--model_id_with_origin_paths`. Currently only supports models whose parameters are not updated by gradients (no gradient backpropagation, or gradients only update their LoRA).\n* Training base configuration\n    * `--learning_rate`: Learning rate.\n    * `--num_epochs`: Number of epochs.\n    * `--trainable_models`: Trainable models, for example `dit`, `vae`, `text_encoder`.\n    * `--find_unused_parameters`: Whether there are unused parameters in DDP training. Some models contain redundant parameters that do not participate in gradient calculation, and this setting needs to be enabled to avoid errors in multi-GPU training.\n    * `--weight_decay`: Weight decay size. 
See [torch.optim.AdamW](https://docs.pytorch.org/docs/stable/generated/torch.optim.AdamW.html) for details.\n    * `--task`: Training task, default is `sft`. Some models support more training modes. Please refer to the documentation for each specific model.\n* Output configuration\n    * `--output_path`: Model save path.\n    * `--remove_prefix_in_ckpt`: Remove prefixes in the state dict of model files.\n    * `--save_steps`: Interval of training steps for saving models. If this parameter is left blank, the model will be saved once per epoch.\n* LoRA configuration\n    * `--lora_base_model`: Which model LoRA is added to.\n    * `--lora_target_modules`: Which layers LoRA is added to.\n    * `--lora_rank`: Rank of LoRA.\n    * `--lora_checkpoint`: Path of LoRA checkpoint. If this path is provided, LoRA will be loaded from this checkpoint.\n    * `--preset_lora_path`: Preset LoRA checkpoint path. If this path is provided, this LoRA will be loaded in the form of being merged into the base model. This parameter is used for LoRA differential training.\n    * `--preset_lora_model`: Model that preset LoRA is merged into, for example `dit`.\n* Gradient configuration\n    * `--use_gradient_checkpointing`: Whether to enable gradient checkpointing.\n    * `--use_gradient_checkpointing_offload`: Whether to offload gradient checkpointing to memory.\n    * `--gradient_accumulation_steps`: Number of gradient accumulation steps.\n* Image dimension configuration (applicable to image generation models and video generation models)\n    * `--height`: Height of images or videos. Leave `height` and `width` blank to enable dynamic resolution.\n    * `--width`: Width of images or videos. Leave `height` and `width` blank to enable dynamic resolution.\n    * `--max_pixels`: Maximum pixel area of images or video frames. 
When dynamic resolution is enabled, images with resolution larger than this value will be scaled down, and images with resolution smaller than this value will remain unchanged.\n\nSome models' training scripts also contain additional parameters. See the documentation for each model for details.\n\n## Preparing Datasets\n\n`DiffSynth-Studio` adopts a universal dataset format. The dataset contains a series of data files (images, videos, etc.) and annotated metadata files. We recommend organizing dataset files as follows:\n\n```\ndata/example_image_dataset/\n├── metadata.csv\n├── image_1.jpg\n└── image_2.jpg\n```\n\nWhere `image_1.jpg`, `image_2.jpg` are training image data, and `metadata.csv` is the metadata list, for example:\n\n```\nimage,prompt\nimage_1.jpg,\"a dog\"\nimage_2.jpg,\"a cat\"\n```\n\nWe have built sample datasets for your testing. To understand how the universal dataset architecture is implemented, please refer to [`diffsynth.core.data`](../API_Reference/core/data.md).\n\n<details>\n\n<summary>Sample Dataset</summary>\n\n> ```shell\n> modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --local_dir ./data/diffsynth_example_dataset\n> ```\n\n</details>\n\n## Loading Models\n\nSimilar to [model loading during inference](../Pipeline_Usage/Model_Inference.md#loading-models), we support multiple ways to configure model paths, and the two methods can be mixed.\n\n<details>\n\n<summary>Download and load models from remote sources</summary>\n\n> If we load models during inference through the following settings:\n> \n> ```python\n> model_configs=[\n>     ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n>     ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n>     ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n> ]\n> ```\n> \n> Then during training, fill in the 
following parameters to load the corresponding models:\n> \n> ```shell\n> --model_id_with_origin_paths \"Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors\"\n> ```\n> \n> Model files are downloaded to the `./models` path by default, which can be modified through [environment variable DIFFSYNTH_MODEL_BASE_PATH](../Pipeline_Usage/Environment_Variables.md#diffsynth_model_base_path).\n> \n> By default, even after models have been downloaded, the program will still query remotely for missing files. To completely disable remote requests, set [environment variable DIFFSYNTH_SKIP_DOWNLOAD](../Pipeline_Usage/Environment_Variables.md#diffsynth_skip_download) to `True`.\n\n</details>\n\n<details>\n\n\n<summary>Load models from local file paths</summary>\n\n> If loading models from local files during inference, for example:\n> \n> ```python\n> model_configs=[\n>     ModelConfig([\n>         \"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00001-of-00009.safetensors\",\n>         \"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00002-of-00009.safetensors\",\n>         \"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00003-of-00009.safetensors\",\n>         \"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00004-of-00009.safetensors\",\n>         \"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00005-of-00009.safetensors\",\n>         \"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00006-of-00009.safetensors\",\n>         \"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00007-of-00009.safetensors\",\n>         \"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00008-of-00009.safetensors\",\n>         \"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00009-of-00009.safetensors\"\n>     ]),\n>     ModelConfig([\n>         
\"models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors\",\n>         \"models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors\",\n>         \"models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors\",\n>         \"models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors\"\n>     ]),\n>     ModelConfig(\"models/Qwen/Qwen-Image/vae/diffusion_pytorch_model.safetensors\")\n> ]\n> ```\n> \n> Then during training, set to:\n> \n> ```shell\n> --model_paths '[\n>     [\n>         \"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00001-of-00009.safetensors\",\n>         \"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00002-of-00009.safetensors\",\n>         \"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00003-of-00009.safetensors\",\n>         \"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00004-of-00009.safetensors\",\n>         \"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00005-of-00009.safetensors\",\n>         \"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00006-of-00009.safetensors\",\n>         \"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00007-of-00009.safetensors\",\n>         \"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00008-of-00009.safetensors\",\n>         \"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00009-of-00009.safetensors\"\n>     ],\n>     [\n>         \"models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors\",\n>         \"models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors\",\n>         \"models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors\",\n>         \"models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors\"\n>     ],\n>     \"models/Qwen/Qwen-Image/vae/diffusion_pytorch_model.safetensors\"\n> ]' \\\n> ```\n> \n> Note that `--model_paths` is in JSON format, and extra `,` cannot appear in it, otherwise it 
cannot be parsed normally.\n\n</details>\n\n## Setting Trainable Modules\n\nThe training framework supports training of any model. Taking Qwen-Image as an example, to fully train the DiT model, set to:\n\n```shell\n--trainable_models \"dit\"\n```\n\nTo train LoRA of the DiT model, set to:\n\n```shell\n--lora_base_model dit --lora_target_modules \"to_q,to_k,to_v\" --lora_rank 32\n```\n\nWe hope to leave enough room for technical exploration, so the framework supports training any number of modules simultaneously. For example, to train the text encoder, controlnet, and LoRA of the DiT simultaneously:\n\n```shell\n--trainable_models \"text_encoder,controlnet\" --lora_base_model dit --lora_target_modules \"to_q,to_k,to_v\" --lora_rank 32\n```\n\nAdditionally, since the training script loads multiple modules (text encoder, dit, vae, etc.), prefixes need to be removed when saving model files. For example, when fully training the DiT part or training the LoRA model of the DiT part, please set `--remove_prefix_in_ckpt pipe.dit.`. If multiple modules are trained simultaneously, developers need to write code to split the state dict in the model file after training is completed.\n\n## Starting the Training Program\n\nThe training framework is built on [`accelerate`](https://huggingface.co/docs/accelerate/index). Training commands are written in the following format:\n\n```shell\naccelerate launch xxx/train.py \\\n  --xxx yyy \\\n  --xxxx yyyy\n```\n\nWe have written preset training scripts for each model. See the documentation for each model for details.\n\nBy default, `accelerate` will train according to the configuration in `~/.cache/huggingface/accelerate/default_config.yaml`. Use `accelerate config` to configure interactively in the terminal, including multi-GPU training, [`DeepSpeed`](https://www.deepspeed.ai/), etc.\n\nWe provide recommended `accelerate` configuration files for some models, which can be set through `--config_file`. 
For example, full training of the Qwen-Image model:\n\n```shell\naccelerate launch --config_file examples/qwen_image/model_training/full/accelerate_config_zero2offload.yaml examples/qwen_image/model_training/train.py \\\n  --dataset_base_path data/example_image_dataset \\\n  --dataset_metadata_path data/example_image_dataset/metadata.csv \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 50 \\\n  --model_id_with_origin_paths \"Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 2 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Qwen-Image_full\" \\\n  --trainable_models \"dit\" \\\n  --use_gradient_checkpointing \\\n  --find_unused_parameters\n```\n\n## Training Considerations\n\n* In addition to the `csv` format, dataset metadata also supports `json` and `jsonl` formats. For how to choose the best metadata format, please refer to [../API_Reference/core/data.md#metadata](../API_Reference/core/data.md#metadata)\n* Training effectiveness is usually strongly correlated with training steps and weakly correlated with epoch count. Therefore, we recommend using the `--save_steps` parameter to save model files at training step intervals.\n* When data volume * `dataset_repeat` exceeds $10^9$, we observed that the dataset speed becomes significantly slower, which seems to be a `PyTorch` bug. We are not sure if newer versions of `PyTorch` have fixed this issue.\n* For learning rate `--learning_rate`, it is recommended to set to `1e-4` in LoRA training and `1e-5` in full training.\n* The training framework does not support batch size > 1. The reasons are complex. See [Q&A: Why doesn't the training framework support batch size > 1?](../QA.md#why-doesnt-the-training-framework-support-batch-size--1)\n* Some models contain redundant parameters. 
For example, the text encoding part of the last layer of Qwen-Image's DiT part. When training these models, `--find_unused_parameters` needs to be set to avoid errors in multi-GPU training. For compatibility with community models, we do not intend to remove these redundant parameters.\n* The loss function value of Diffusion models has little relationship with actual effects. Therefore, we do not record loss function values during training. We recommend setting `--num_epochs` to a sufficiently large value, testing while training, and manually closing the training program after the effect converges.\n* `--use_gradient_checkpointing` is usually enabled unless GPU VRAM is sufficient; `--use_gradient_checkpointing_offload` is enabled as needed. See [`diffsynth.core.gradient`](../API_Reference/core/gradient.md) for details.\n\n## Low VRAM Training\n\nIf you want to complete LoRA model training on GPU with low vram, you can combine [Two-Stage Split Training](../Training/Split_Training.md) with `deepspeed_zero3_offload` training. First, split the preprocessing steps into the first stage and store the computed results onto the hard disk. Second, read these results from the disk and train the denoising model. By using `deepspeed_zero3_offload`, the training parameters and optimizer states are offloaded to the CPU or disk. We provide examples for some models, primarily by specifying the `deepspeed` configuration via `--config_file`.\n\nPlease note that the `deepspeed_zero3_offload` mode is incompatible with PyTorch's native gradient checkpointing mechanism. To address this, we have adapted the `checkpointing` interface of `deepspeed`. 
Users need to fill the `activation_checkpointing` field in the `deepspeed` configuration to enable gradient checkpointing.\n\nBelow is the script for low VRAM model training for the Qwen-Image model:\n\n```shell\naccelerate launch examples/qwen_image/model_training/train.py \\\n  --dataset_base_path data/example_image_dataset \\\n  --dataset_metadata_path data/example_image_dataset/metadata.csv \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 1 \\\n  --model_id_with_origin_paths \"Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Qwen-Image_lora-splited-cache\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1\" \\\n  --lora_rank 32 \\\n  --task \"sft:data_process\" \\\n  --use_gradient_checkpointing \\\n  --dataset_num_workers 8 \\\n  --find_unused_parameters\n\naccelerate launch --config_file examples/qwen_image/model_training/special/low_vram_training/deepspeed_zero3_cpuoffload.yaml examples/qwen_image/model_training/train.py \\\n  --dataset_base_path \"./models/train/Qwen-Image_lora-splited-cache\" \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 50 \\\n  --model_id_with_origin_paths \"Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Qwen-Image_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1\" \\\n  --lora_rank 32 \\\n  --task \"sft:train\" \\\n  --use_gradient_checkpointing \\\n  --dataset_num_workers 8 \\\n  --find_unused_parameters \\\n  
--initialize_model_on_cpu\n```\n\nThe configurations for `accelerate` and `deepspeed` are as follows:\n\n```yaml\ncompute_environment: LOCAL_MACHINE\ndebug: true\ndeepspeed_config:\n  deepspeed_config_file: examples/qwen_image/model_training/special/low_vram_training/ds_z3_cpuoffload.json\n  zero3_init_flag: true\ndistributed_type: DEEPSPEED\ndowncast_bf16: 'no'\nenable_cpu_affinity: false\nmachine_rank: 0\nmain_training_function: main\nnum_machines: 1\nnum_processes: 1\nrdzv_backend: static\nsame_network: true\ntpu_env: []\ntpu_use_cluster: false\ntpu_use_sudo: false\nuse_cpu: false\n```\n\n```json\n{\n    \"fp16\": {\n        \"enabled\": \"auto\",\n        \"loss_scale\": 0,\n        \"loss_scale_window\": 1000,\n        \"initial_scale_power\": 16,\n        \"hysteresis\": 2,\n        \"min_loss_scale\": 1\n    },\n    \"bf16\": {\n        \"enabled\": \"auto\"\n    },\n    \"zero_optimization\": {\n        \"stage\": 3,\n        \"offload_optimizer\": {\n            \"device\": \"cpu\",\n            \"pin_memory\": true\n        },\n        \"offload_param\": {\n            \"device\": \"cpu\",\n            \"pin_memory\": true\n        },\n        \"overlap_comm\": false,\n        \"contiguous_gradients\": true,\n        \"sub_group_size\": 1e9,\n        \"reduce_bucket_size\": 5e7,\n        \"stage3_prefetch_bucket_size\": 5e7,\n        \"stage3_param_persistence_threshold\": 1e5,\n        \"stage3_max_live_parameters\": 1e8,\n        \"stage3_max_reuse_distance\": 1e8,\n        \"stage3_gather_16bit_weights_on_model_save\": true\n    },\n    \"activation_checkpointing\": {\n        \"partition_activations\": false,\n        \"cpu_checkpointing\": false,\n        \"contiguous_memory_optimization\": false\n    },\n    \"gradient_accumulation_steps\": \"auto\",\n    \"gradient_clipping\": \"auto\",\n    \"train_batch_size\": \"auto\",\n    \"train_micro_batch_size_per_gpu\": \"auto\",\n    \"wall_clock_breakdown\": false\n}\n```"
  },
  {
    "path": "docs/en/Pipeline_Usage/Setup.md",
    "content": "# Installing Dependencies\n\nInstall from source (recommended):\n\n```\ngit clone https://github.com/modelscope/DiffSynth-Studio.git\ncd DiffSynth-Studio\npip install -e .\n```\n\nInstall from PyPI (releases may lag behind the source; for the latest features, install from source):\n\n```\npip install diffsynth\n```\n\n## GPU/NPU Support\n\n* **NVIDIA GPU**\n\nInstall as described above.\n\n* **AMD GPU**\n\nYou need to install the `torch` package with ROCm support. Taking ROCm 6.4 (current as of this article's last update, December 15, 2025) on Linux as an example, run the following command:\n\n```shell\npip install torch torchvision --index-url https://download.pytorch.org/whl/rocm6.4\n```\n\n* **Ascend NPU**\n\n1. Install [CANN](https://www.hiascend.com/document/detail/zh/canncommercial/83RC1/softwareinst/instg/instg_quick.html?Mode=PmIns&InstallType=local&OS=openEuler&Software=cannToolKit) by following the official documentation.\n\n2. Install from source:\n   ```shell\n   git clone https://github.com/modelscope/DiffSynth-Studio.git\n   cd DiffSynth-Studio\n   # aarch64/ARM\n   pip install -e .[npu_aarch64]\n   # x86\n   pip install -e .[npu] --extra-index-url \"https://download.pytorch.org/whl/cpu\"\n   ```\n\nWhen using Ascend NPU, please replace `\"cuda\"` with `\"npu\"` in your Python code. For details, see [NPU Support](../Pipeline_Usage/GPU_support.md#ascend-npu).\n\n## Other Installation Issues\n\nIf you encounter issues during installation, they may be caused by upstream dependencies. Please refer to the documentation for these packages:\n\n* [torch](https://pytorch.org/get-started/locally/)\n* [Ascend/pytorch](https://github.com/Ascend/pytorch)\n* [sentencepiece](https://github.com/google/sentencepiece)\n* [cmake](https://cmake.org)\n"
  },
  {
    "path": "docs/en/Pipeline_Usage/VRAM_management.md",
    "content": "# VRAM Management\n\nVRAM management is a distinctive feature of `DiffSynth-Studio` that enables GPUs with low VRAM to run inference with large parameter models. This document uses Qwen-Image as an example to introduce how to use the VRAM management solution.\n\n## Basic Inference\n\nThe following code does not enable any VRAM management, occupying 56G VRAM as a reference.\n\n```python\nfrom diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nimport torch\n\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n)\nprompt = \"Exquisite portrait, underwater girl, blue dress flowing, hair floating, translucent light, bubbles surrounding, peaceful face, intricate details, dreamy and ethereal.\"\nimage = pipe(prompt, seed=0, num_inference_steps=40)\nimage.save(\"image.jpg\")\n```\n\n## CPU Offload\n\nSince the model `Pipeline` consists of multiple components that are not called simultaneously, we can move some components to memory when they are not needed for computation, reducing VRAM usage. 
The following code implements this logic, occupying 40G VRAM.\n\n```python\nfrom diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nimport torch\n\nvram_config = {\n    \"offload_dtype\": torch.bfloat16,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cuda\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n)\nprompt = \"Exquisite portrait, underwater girl, blue dress flowing, hair floating, translucent light, bubbles surrounding, peaceful face, intricate details, dreamy and ethereal.\"\nimage = pipe(prompt, seed=0, num_inference_steps=40)\nimage.save(\"image.jpg\")\n```\n\n## FP8 Quantization\n\nBuilding upon CPU Offload, we further enable FP8 quantization to reduce VRAM requirements. The following code allows model parameters to be stored in VRAM with FP8 precision and temporarily converted to BF16 precision for computation during inference, occupying 21G VRAM. 
However, this quantization scheme has minor image quality degradation issues.\n\n```python\nfrom diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nimport torch\n\nvram_config = {\n    \"offload_dtype\": torch.float8_e4m3fn,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.float8_e4m3fn,\n    \"onload_device\": \"cuda\",\n    \"preparing_dtype\": torch.float8_e4m3fn,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n)\nprompt = \"Exquisite portrait, underwater girl, blue dress flowing, hair floating, translucent light, bubbles surrounding, peaceful face, intricate details, dreamy and ethereal.\"\nimage = pipe(prompt, seed=0, num_inference_steps=40)\nimage.save(\"image.jpg\")\n```\n\n> Q: Why temporarily convert to BF16 precision during inference instead of computing with FP8 precision?\n> \n> A: Native FP8 computation is only supported on Hopper architecture GPUs (such as H20) and has significant computational errors. We currently do not enable FP8 precision computation. The current FP8 quantization only reduces VRAM usage but does not improve computation speed.\n\n## Dynamic VRAM Management\n\nIn CPU Offload, we control model components. 
In fact, we support Layer-level Offload, splitting a model into multiple Layers, keeping some resident in VRAM and storing others in memory for on-demand transfer to VRAM for computation. This feature requires model developers to provide detailed VRAM management solutions for each model. Related configurations are in `diffsynth/configs/vram_management_module_maps.py`.\n\nBy adding the `vram_limit` parameter to the `Pipeline`, the framework can automatically sense the remaining VRAM of the device and decide how to split the model between VRAM and memory. The smaller the `vram_limit`, the less VRAM occupied, but slower the speed.\n* When `vram_limit=None`, the default state, the framework assumes unlimited VRAM and dynamic VRAM management is disabled\n* When `vram_limit=10`, the framework will limit the model after VRAM usage exceeds 10G, moving the excess parts to memory storage\n* When `vram_limit=0`, the framework will do its best to reduce VRAM usage, storing all model parameters in memory and transferring them to VRAM for computation only when necessary\n\nWhen VRAM is insufficient to run model inference, the framework will attempt to exceed the `vram_limit` restriction to keep the model inference running. Therefore, the VRAM management framework cannot always guarantee that VRAM usage will be less than `vram_limit`. We recommend setting it to slightly less than the actual available VRAM. For example, when GPU VRAM is 16G, set it to `vram_limit=15.5`. 
In `PyTorch`, you can use `torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3)` to get the GPU's VRAM.\n\n```python\nfrom diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nimport torch\n\nvram_config = {\n    \"offload_dtype\": torch.float8_e4m3fn,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.float8_e4m3fn,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e4m3fn,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\nprompt = \"Exquisite portrait, underwater girl, blue dress flowing, hair floating, translucent light, bubbles surrounding, peaceful face, intricate details, dreamy and ethereal.\"\nimage = pipe(prompt, seed=0, num_inference_steps=40)\nimage.save(\"image.jpg\")\n```\n\n## Disk Offload\n\nIn more extreme cases, when memory is also insufficient to store the entire model, the Disk Offload feature allows lazy loading of model parameters, meaning each Layer of the model only reads the corresponding parameters from disk when the forward function is called. 
When enabling this feature, we recommend using a high-speed SSD.\n\nDisk Offload is a special VRAM management mode: it only supports `.safetensors` files (not `.bin`, `.pth`, `.ckpt`, or other binary formats), and it does not support [state dict converters](../Developer_Guide/Integrating_Your_Model.md#step-2-model-file-format-conversion) that reshape tensors.\n\n```python\nfrom diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nimport torch\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": \"disk\",\n    \"onload_device\": \"disk\",\n    \"preparing_dtype\": torch.float8_e4m3fn,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n    vram_limit=10,\n)\nprompt = \"Exquisite portrait, underwater girl, blue dress flowing, hair floating, translucent light, bubbles surrounding, peaceful face, intricate details, dreamy and ethereal.\"\nimage = pipe(prompt, seed=0, num_inference_steps=40)\nimage.save(\"image.jpg\")\n```\n\n## More Usage Methods\n\nThe fields in `vram_config` can also be set manually. For example, Disk Offload without FP8 quantization:\n\n```python\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": \"disk\",\n    \"onload_device\": \"disk\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\n```\n\nSpecifically, the VRAM management module places each model Layer in one of the following four states:\n\n* Offload: The model will not be called in the near term. This state is managed by the `Pipeline` as it switches between components\n* Onload: The model may be called at any moment. This state is managed by the `Pipeline` as it switches between components\n* Preparing: An intermediate state between Onload and Computation, used as temporary storage when VRAM allows. This state is controlled by the VRAM management mechanism and is entered if and only if [vram_limit is set to unlimited] or [vram_limit is set and there is spare VRAM]\n* Computation: The model is actively computing. This state is controlled by the VRAM management mechanism and is entered only temporarily during `forward`\n\nIf you are a model developer and want to control the VRAM management granularity of a specific model, please refer to [Enabling Fine-Grained VRAM Management](../Developer_Guide/Enabling_VRAM_management.md).\n\n## Best Practices\n\n* Sufficient VRAM -> Use [Basic Inference](#basic-inference)\n* Insufficient VRAM\n    * Sufficient memory -> Use [Dynamic VRAM Management](#dynamic-vram-management)\n    * Insufficient memory -> Use [Disk Offload](#disk-offload)"
  },
  {
    "path": "docs/en/QA.md",
    "content": "# Frequently Asked Questions\n\n## Why doesn't the training framework support batch size > 1?\n\n* **Larger batch sizes no longer achieve significant acceleration**: Due to acceleration technologies such as flash attention that have fully improved GPU utilization, larger batch sizes will only bring greater VRAM usage without significant acceleration. The experience with small models like Stable Diffusion 1.5 is no longer applicable to the latest large models.\n* **Larger batch sizes can be achieved through other solutions**: Multi-GPU training and Gradient Accumulation can both mathematically equivalently achieve larger batch sizes.\n* **Larger batch sizes contradict the framework's general design**: We hope to build a general training framework. Many models cannot accommodate larger batch sizes, such as text encodings of different lengths and images of different resolutions, which cannot be merged into larger batches.\n\n## Why aren't redundant parameters removed from certain models?\n\nIn some models, redundant parameters exist. For example, in Qwen-Image's DiT model, the text portion of the last layer does not participate in any calculations. This is a minor bug left by the model developers. Setting it as trainable directly will also cause errors in multi-GPU training.\n\nTo maintain compatibility with other models in the open-source community, we have decided to retain these parameters. These redundant parameters can avoid errors in multi-GPU training through the `--find_unused_parameters` parameter.\n\n## Why does FP8 quantization show no acceleration effect?\n\nNative FP8 computation relies on Hopper architecture GPUs and has significant precision errors. It is currently immature technology, so this project does not support native FP8 computation.\n\nFP8 computation in VRAM management refers to storing model parameters in memory or VRAM with FP8 precision and temporarily converting them to other precisions when needed for computation. 
Therefore, it only reduces VRAM usage and does not accelerate computation.\n\n## Why doesn't the training framework support native FP8 precision training?\n\nEven with suitable hardware, we currently have no plans to support native FP8 precision training.\n\n* The main challenge of native FP8 precision training is numerical overflow caused by exploding gradients. Ensuring training stability requires redesigning the model architecture accordingly, and no model developers have been willing to do so yet.\n* Additionally, on GPUs without the Hopper architecture, models trained with native FP8 precision can only be run in BF16 precision during inference, which in theory yields worse generation quality than native FP8 inference.\n\nTherefore, native FP8 precision training is still extremely immature. We will keep following developments in the open-source community.\n\n## How to dynamically load LoRA models during inference?\n\nWe support two loading methods for LoRA models. See [LoRA Loading](./Pipeline_Usage/Model_Inference.md#loading-lora) for details:\n\n* Cold Loading: When [VRAM Management](./Pipeline_Usage/VRAM_management.md) is not enabled for the base model, LoRA will be fused into the base model weights. In this case, inference speed remains unchanged, but LoRA cannot be unloaded after loading.\n* Hot Loading: When [VRAM Management](./Pipeline_Usage/VRAM_management.md) is enabled for the base model, LoRA will not be fused into the base model weights. In this case, inference is slower, but LoRA can be unloaded after loading via `pipe.clear_lora()`.\n"
  },
  {
    "path": "docs/en/README.md",
    "content": "# DiffSynth-Studio Documentation\n\nWelcome to the magical world of Diffusion models! `DiffSynth-Studio` is an open-source Diffusion model engine developed and maintained by the [ModelScope Community](https://www.modelscope.cn/). We aim to build a universal Diffusion model framework that fosters technological innovation through framework construction, aggregates the power of the open-source community, and explores the boundaries of generative model technology!\n\n<details>\n\n<summary>Documentation Reading Guide</summary>\n\n```mermaid\ngraph LR;\n    I_want_to_use_models_for_inference_and_training-->sec1[Section 1: Getting Started];\n    I_want_to_use_models_for_inference_and_training-->sec2[Section 2: Model Details];\n    I_want_to_use_models_for_inference_and_training-->sec3[Section 3: Training Framework];\n    I_want_to_develop_based_on_this_framework-->sec3[Section 3: Training Framework];\n    I_want_to_develop_based_on_this_framework-->sec4[Section 4: Model Integration];\n    I_want_to_develop_based_on_this_framework-->sec5[Section 5: API Reference];\n    I_want_to_explore_new_technologies_based_on_this_project-->sec4[Section 4: Model Integration];\n    I_want_to_explore_new_technologies_based_on_this_project-->sec5[Section 5: API Reference];\n    I_want_to_explore_new_technologies_based_on_this_project-->sec6[Section 6: Academic Guide];\n    I_encountered_a_problem-->sec7[Section 7: Frequently Asked Questions];\n```\n\n</details>\n\n## Section 1: Getting Started\n\nThis section introduces the basic usage of `DiffSynth-Studio`, including how to enable VRAM management for inference on GPUs with extremely low VRAM, and how to train various base models, LoRAs, ControlNets, and other models.\n\n* [Installation Dependencies](./Pipeline_Usage/Setup.md)\n* [Model Inference](./Pipeline_Usage/Model_Inference.md)\n* [VRAM Management](./Pipeline_Usage/VRAM_management.md)\n* [Model Training](./Pipeline_Usage/Model_Training.md)\n* [Environment 
Variables](./Pipeline_Usage/Environment_Variables.md)\n* [GPU/NPU Support](./Pipeline_Usage/GPU_support.md)\n\n## Section 2: Model Details\n\nThis section introduces the Diffusion models supported by `DiffSynth-Studio`. Some model pipelines feature special functionalities such as controllable generation and parallel acceleration.\n\n* [FLUX.1](./Model_Details/FLUX.md)\n* [Wan](./Model_Details/Wan.md)\n* [Qwen-Image](./Model_Details/Qwen-Image.md)\n* [FLUX.2](./Model_Details/FLUX2.md)\n* [Z-Image](./Model_Details/Z-Image.md)\n* [Anima](./Model_Details/Anima.md)\n* [LTX-2](./Model_Details/LTX-2.md)\n\n## Section 3: Training Framework\n\nThis section introduces the design philosophy of the training framework in `DiffSynth-Studio`, helping developers understand the principles of Diffusion model training algorithms.\n\n* [Basic Principles of Diffusion Models](./Training/Understanding_Diffusion_models.md)\n* [Standard Supervised Training](./Training/Supervised_Fine_Tuning.md)\n* [Enabling FP8 Precision in Training](./Training/FP8_Precision.md)\n* [End-to-End Distillation Accelerated Training](./Training/Direct_Distill.md)\n* [Two-Stage Split Training](./Training/Split_Training.md)\n* [Differential LoRA Training](./Training/Differential_LoRA.md)\n\n## Section 4: Model Integration\n\nThis section introduces how to integrate models into `DiffSynth-Studio` to utilize the framework's basic functions, helping developers provide support for new models in this project or perform inference and training of private models.\n\n* [Integrating Model Architecture](./Developer_Guide/Integrating_Your_Model.md)\n* [Building a Pipeline](./Developer_Guide/Building_a_Pipeline.md)\n* [Enabling Fine-Grained VRAM Management](./Developer_Guide/Enabling_VRAM_management.md)\n* [Model Training Integration](./Developer_Guide/Training_Diffusion_Models.md)\n\n## Section 5: API Reference\n\nThis section introduces the independent core module `diffsynth.core` in `DiffSynth-Studio`, explaining how 
internal functions are designed and operate. Developers can reuse these functional modules in other codebases if needed.\n\n* [`diffsynth.core.attention`](./API_Reference/core/attention.md): Attention mechanism implementation\n* [`diffsynth.core.data`](./API_Reference/core/data.md): Data processing operators and general datasets\n* [`diffsynth.core.gradient`](./API_Reference/core/gradient.md): Gradient checkpointing\n* [`diffsynth.core.loader`](./API_Reference/core/loader.md): Model download and loading\n* [`diffsynth.core.vram`](./API_Reference/core/vram.md): VRAM management\n\n## Section 6: Academic Guide\n\nThis section introduces how to use `DiffSynth-Studio` to train new models, helping researchers explore new model technologies.\n\n* [Training models from scratch](./Research_Tutorial/train_from_scratch.md)\n* [Inference optimization techniques](./Research_Tutorial/inference_time_scaling.md)\n* Designing controllable generation models (coming soon)\n* Creating new training paradigms (coming soon)\n\n## Section 7: Frequently Asked Questions\n\nThis section summarizes common developer questions. If you encounter issues during usage or development, please refer to this section. If you still cannot resolve the problem, please submit an issue on GitHub.\n\n* [Frequently Asked Questions](./QA.md)"
  },
  {
    "path": "docs/en/Research_Tutorial/inference_time_scaling.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"8db54992\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Inference Optimization Techniques\\n\",\n    \"\\n\",\n    \"DiffSynth-Studio aims to drive technological innovation through its foundational framework. This article demonstrates how to build a training-free image generation enhancement solution using DiffSynth-Studio, taking Inference-time scaling as an example.\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"0911cad4\",\n   \"metadata\": {},\n   \"source\": [\n    \"## 1. Image Quality Quantification\\n\",\n    \"\\n\",\n    \"First, we need to find an indicator to quantify image quality from generation models. Manual scoring is the most straightforward solution but too costly for large-scale applications. However, after collecting manual scores, training an image classification model to predict human scoring is completely feasible. PickScore [[1]](https://arxiv.org/abs/2305.01569) is such a model. 
Running the following code will automatically download and load the [PickScore model](https://modelscope.cn/models/AI-ModelScope/PickScore_v1).\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"4faca4ca\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from modelscope import AutoProcessor, AutoModel\\n\",\n    \"import torch\\n\",\n    \"\\n\",\n    \"class PickScore(torch.nn.Module):\\n\",\n    \"    def __init__(self):\\n\",\n    \"        super().__init__()\\n\",\n    \"        self.processor = AutoProcessor.from_pretrained(\\\"laion/CLIP-ViT-H-14-laion2B-s32B-b79K\\\")\\n\",\n    \"        self.model = AutoModel.from_pretrained(\\\"AI-ModelScope/PickScore_v1\\\").eval().to(\\\"cuda\\\")\\n\",\n    \"\\n\",\n    \"    def forward(self, image, prompt):\\n\",\n    \"        image_inputs = self.processor(images=image, padding=True, truncation=True, max_length=77, return_tensors=\\\"pt\\\").to(\\\"cuda\\\")\\n\",\n    \"        text_inputs = self.processor(text=prompt, padding=True, truncation=True, max_length=77, return_tensors=\\\"pt\\\").to(\\\"cuda\\\")\\n\",\n    \"        with torch.inference_mode():\\n\",\n    \"            image_embs = self.model.get_image_features(**image_inputs).pooler_output\\n\",\n    \"            image_embs = image_embs / torch.norm(image_embs, dim=-1, keepdim=True)\\n\",\n    \"            text_embs = self.model.get_text_features(**text_inputs).pooler_output\\n\",\n    \"            text_embs = text_embs / torch.norm(text_embs, dim=-1, keepdim=True)\\n\",\n    \"            score = (text_embs @ image_embs.T).flatten().item()\\n\",\n    \"        return score\\n\",\n    \"\\n\",\n    \"reward_model = PickScore()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"5f807cec\",\n   \"metadata\": {},\n   \"source\": [\n    \"## 2. 
Inference-time Scaling Techniques\\n\",\n    \"\\n\",\n    \"Inference-time Scaling [[2]](https://arxiv.org/abs/2504.00294) is an interesting technique that improves generation quality by spending more computation at inference time. For example, in language models, models like [Qwen/Qwen3.5-27B](https://modelscope.cn/models/Qwen/Qwen3.5-27B) and [deepseek-ai/DeepSeek-R1](https://modelscope.cn/models/deepseek-ai/DeepSeek-R1) use a \\\"thinking mode\\\" that lets the model reason for longer before answering, producing more accurate answers. Next, we'll use the [black-forest-labs/FLUX.2-klein-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B) model as an example to explore how to design Inference-time Scaling solutions for image generation models.\\n\",\n    \"\\n\",\n    \"> Before starting, we slightly modified the `Flux2ImagePipeline` code to allow initialization with specific Gaussian noise matrices for result reproducibility. See `Flux2Unit_NoiseInitializer` in [diffsynth/pipelines/flux2_image.py](https://github.com/modelscope/DiffSynth-Studio/blob/main/diffsynth/pipelines/flux2_image.py).\\n\",\n    \"\\n\",\n    \"Run the following code to load the [black-forest-labs/FLUX.2-klein-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B) model.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"c5818a87\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig\\n\",\n    \"\\n\",\n    \"pipe = Flux2ImagePipeline.from_pretrained(\\n\",\n    \"    torch_dtype=torch.bfloat16,\\n\",\n    \"    device=\\\"cuda\\\",\\n\",\n    \"    model_configs=[\\n\",\n    \"        ModelConfig(model_id=\\\"black-forest-labs/FLUX.2-klein-4B\\\", origin_file_pattern=\\\"text_encoder/*.safetensors\\\"),\\n\",\n    \"        ModelConfig(model_id=\\\"black-forest-labs/FLUX.2-klein-4B\\\", 
origin_file_pattern=\\\"transformer/*.safetensors\\\"),\\n\",\n    \"        ModelConfig(model_id=\\\"black-forest-labs/FLUX.2-klein-4B\\\", origin_file_pattern=\\\"vae/diffusion_pytorch_model.safetensors\\\"),\\n\",\n    \"    ],\\n\",\n    \"    tokenizer_config=ModelConfig(model_id=\\\"black-forest-labs/FLUX.2-klein-4B\\\", origin_file_pattern=\\\"tokenizer/\\\"),\\n\",\n    \")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"f58e9945\",\n   \"metadata\": {},\n   \"source\": [\n    \"Generate a sketch cat image using the prompt `\\\"sketch, a cat\\\"` and score it with the PickScore model.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"6ea2d258\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def evaluate_noise(noise, pipe, reward_model, prompt):\\n\",\n    \"    # Generate an image and compute the score.\\n\",\n    \"    image = pipe(\\n\",\n    \"        prompt=prompt,\\n\",\n    \"        num_inference_steps=4,\\n\",\n    \"        initial_noise=noise,\\n\",\n    \"        progress_bar_cmd=lambda x: x,\\n\",\n    \"    )\\n\",\n    \"    score = reward_model(image, prompt)\\n\",\n    \"    return score\\n\",\n    \"\\n\",\n    \"torch.manual_seed(1)\\n\",\n    \"prompt = \\\"sketch, a cat\\\"\\n\",\n    \"noise = pipe.generate_noise((1, 128, 64, 64), rand_device=\\\"cuda\\\", rand_torch_dtype=pipe.torch_dtype)\\n\",\n    \"\\n\",\n    \"image_1 = pipe(prompt, num_inference_steps=4, initial_noise=noise)\\n\",\n    \"print(\\\"Score:\\\", reward_model(image_1, prompt))\\n\",\n    \"image_1\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"5e11694e\",\n   \"metadata\": {},\n   \"source\": [\n    \"### 2.1 Best-of-N Random Search\\n\",\n    \"\\n\",\n    \"Model generation results have inherent randomness. Different random seeds produce different images - sometimes high quality, sometimes low. 
This leads to a simple Inference-time scaling solution: generate images using multiple random seeds, score them with PickScore, and retain only the highest-scoring image.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"241f10d2\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from tqdm import tqdm\\n\",\n    \"\\n\",\n    \"def random_search(base_latents, objective_reward_fn, total_eval_budget):\\n\",\n    \"    # Search for the noise randomly.\\n\",\n    \"    best_noise = base_latents\\n\",\n    \"    best_score = objective_reward_fn(base_latents)\\n\",\n    \"    for it in tqdm(range(total_eval_budget - 1)):\\n\",\n    \"        noise = pipe.generate_noise((1, 128, 64, 64), seed=None)\\n\",\n    \"        score = objective_reward_fn(noise)\\n\",\n    \"        if score > best_score:\\n\",\n    \"            best_score, best_noise = score, noise\\n\",\n    \"    return best_noise\\n\",\n    \"\\n\",\n    \"best_noise = random_search(\\n\",\n    \"    base_latents=noise,\\n\",\n    \"    objective_reward_fn=lambda noise: evaluate_noise(noise, pipe, reward_model, prompt),\\n\",\n    \"    total_eval_budget=50,\\n\",\n    \")\\n\",\n    \"image_2 = pipe(prompt, num_inference_steps=4, initial_noise=best_noise)\\n\",\n    \"print(\\\"Score:\\\", reward_model(image_2, prompt))\\n\",\n    \"image_2\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"8e9bf966\",\n   \"metadata\": {},\n   \"source\": [\n    \"We can clearly see that after multiple random searches, the final selected cat image shows richer fur details and significantly improved PickScore. However, this brute-force random search is extremely inefficient - generation time multiplies while easily hitting quality limits. 
Therefore, we need a more efficient search method that achieves higher scores within the same computational budget.\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"c9578349\",\n   \"metadata\": {},\n   \"source\": [\n    \"### 2.2 SES Search\\n\",\n    \"\\n\",\n    \"To overcome the limitations of random search, we introduce the Spectral Evolution Search (SES) algorithm [[3]](https://arxiv.org/abs/2602.03208). Detailed code is available at [diffsynth/utils/ses](https://github.com/modelscope/DiffSynth-Studio/blob/main/diffsynth/utils/ses).\\n\",\n    \"\\n\",\n    \"Image generation in diffusion models is largely determined by low-frequency components in the initial noise. The SES algorithm decomposes Gaussian noise through wavelet transforms, fixes high-frequency details, and applies an evolutionary search using the cross-entropy method specifically on low-frequency components to find optimal initial noise with higher efficiency.\\n\",\n    \"\\n\",\n    \"Run the following code to use SES to efficiently search for the best initial Gaussian noise.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"adeed2aa\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from diffsynth.utils.ses import ses_search\\n\",\n    \"\\n\",\n    \"best_noise = ses_search(\\n\",\n    \"    base_latents=noise,\\n\",\n    \"    objective_reward_fn=lambda noise: evaluate_noise(noise, pipe, reward_model, prompt),\\n\",\n    \"    total_eval_budget=50,\\n\",\n    \")\\n\",\n    \"image_3 = pipe(prompt, num_inference_steps=4, initial_noise=best_noise)\\n\",\n    \"print(\\\"Score:\\\", reward_model(image_3, prompt))\\n\",\n    \"image_3\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"940a97f1\",\n   \"metadata\": {},\n   \"source\": [\n    \"The results show that, under the same computational budget, SES achieves a significantly higher PickScore than random search. 
The \\\"sketch cat\\\" demonstrates more refined overall composition and more layered contrast between light and shadow.\\n\",\n    \"\\n\",\n    \"Inference-time scaling can achieve higher image quality at the cost of longer inference time. The generated image data can then be used to train the model itself through methods like DPO [[4]](https://arxiv.org/abs/2311.12908) or differential training [[5]](https://arxiv.org/abs/2412.12888), opening another interesting research direction.\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.10.19\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 5\n}\n"
  },
  {
    "path": "docs/en/Research_Tutorial/inference_time_scaling.md",
    "content": "# Inference Optimization Techniques\n\nDiffSynth-Studio aims to drive technological innovation through its foundational framework. This article demonstrates how to build a training-free image generation enhancement solution using DiffSynth-Studio, taking Inference-time scaling as an example.\n\nNotebook: https://github.com/modelscope/DiffSynth-Studio/blob/main/docs/en/Research_Tutorial/inference_time_scaling.ipynb\n\n## 1. Image Quality Quantification\n\nFirst, we need to find an indicator to quantify image quality from generation models. Manual scoring is the most straightforward solution but too costly for large-scale applications. However, after collecting manual scores, training an image classification model to predict human scoring is completely feasible. PickScore [[1]](https://arxiv.org/abs/2305.01569) is such a model. Running the following code will automatically download and load the [PickScore model](https://modelscope.cn/models/AI-ModelScope/PickScore_v1).\n\n```python\nfrom modelscope import AutoProcessor, AutoModel\nimport torch\n\nclass PickScore(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.processor = AutoProcessor.from_pretrained(\"laion/CLIP-ViT-H-14-laion2B-s32B-b79K\")\n        self.model = AutoModel.from_pretrained(\"AI-ModelScope/PickScore_v1\").eval().to(\"cuda\")\n\n    def forward(self, image, prompt):\n        image_inputs = self.processor(images=image, padding=True, truncation=True, max_length=77, return_tensors=\"pt\").to(\"cuda\")\n        text_inputs = self.processor(text=prompt, padding=True, truncation=True, max_length=77, return_tensors=\"pt\").to(\"cuda\")\n        with torch.inference_mode():\n            image_embs = self.model.get_image_features(**image_inputs).pooler_output\n            image_embs = image_embs / torch.norm(image_embs, dim=-1, keepdim=True)\n            text_embs = self.model.get_text_features(**text_inputs).pooler_output\n            text_embs = text_embs / 
torch.norm(text_embs, dim=-1, keepdim=True)\n            score = (text_embs @ image_embs.T).flatten().item()\n        return score\n\nreward_model = PickScore()\n```\n\n## 2. Inference-time Scaling Techniques\n\nInference-time Scaling [[2]](https://arxiv.org/abs/2504.00294) is an interesting technique that improves generation quality by spending more computation at inference time. For example, in language models, models like [Qwen/Qwen3.5-27B](https://modelscope.cn/models/Qwen/Qwen3.5-27B) and [deepseek-ai/DeepSeek-R1](https://modelscope.cn/models/deepseek-ai/DeepSeek-R1) use a \"thinking mode\" that lets the model reason for longer before answering, producing more accurate answers. Next, we'll use the [black-forest-labs/FLUX.2-klein-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B) model as an example to explore how to design Inference-time Scaling solutions for image generation models.\n\n> Before starting, we slightly modified the `Flux2ImagePipeline` code to allow initialization with specific Gaussian noise matrices for result reproducibility. 
See `Flux2Unit_NoiseInitializer` in [diffsynth/pipelines/flux2_image.py](https://github.com/modelscope/DiffSynth-Studio/blob/main/diffsynth/pipelines/flux2_image.py).\n\nRun the following code to load the [black-forest-labs/FLUX.2-klein-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B) model.\n\n```python\nfrom diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig\n\npipe = Flux2ImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-4B\", origin_file_pattern=\"text_encoder/*.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-4B\", origin_file_pattern=\"transformer/*.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-4B\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-4B\", origin_file_pattern=\"tokenizer/\"),\n)\n```\n\nGenerate a sketch cat image using the prompt `\"sketch, a cat\"` and score it with the PickScore model.\n\n```python\ndef evaluate_noise(noise, pipe, reward_model, prompt):\n    # Generate an image and compute the score.\n    image = pipe(\n        prompt=prompt,\n        num_inference_steps=4,\n        initial_noise=noise,\n        progress_bar_cmd=lambda x: x,\n    )\n    score = reward_model(image, prompt)\n    return score\n\ntorch.manual_seed(1)\nprompt = \"sketch, a cat\"\nnoise = pipe.generate_noise((1, 128, 64, 64), rand_device=\"cuda\", rand_torch_dtype=pipe.torch_dtype)\n\nimage_1 = pipe(prompt, num_inference_steps=4, initial_noise=noise)\nprint(\"Score:\", reward_model(image_1, prompt))\nimage_1\n```\n\n![Image](https://github.com/user-attachments/assets/b6546c6d-b368-4463-b703-d561a9134ba0)\n\n### 2.1 Best-of-N Random Search\n\nModel generation results have inherent randomness. 
Different random seeds produce different images - sometimes high quality, sometimes low. This leads to a simple Inference-time scaling solution: generate images using multiple random seeds, score them with PickScore, and retain only the highest-scoring image.\n\n```python\nfrom tqdm import tqdm\n\ndef random_search(base_latents, objective_reward_fn, total_eval_budget):\n    # Search for the noise randomly.\n    best_noise = base_latents\n    best_score = objective_reward_fn(base_latents)\n    for it in tqdm(range(total_eval_budget - 1)):\n        noise = pipe.generate_noise((1, 128, 64, 64), seed=None)\n        score = objective_reward_fn(noise)\n        if score > best_score:\n            best_score, best_noise = score, noise\n    return best_noise\n\nbest_noise = random_search(\n    base_latents=noise,\n    objective_reward_fn=lambda noise: evaluate_noise(noise, pipe, reward_model, prompt),\n    total_eval_budget=50,\n)\nimage_2 = pipe(prompt, num_inference_steps=4, initial_noise=best_noise)\nprint(\"Score:\", reward_model(image_2, prompt))\nimage_2\n```\n\n![Image](https://github.com/user-attachments/assets/b8dba70a-daa8-4368-8f32-a6c150daecb5)\n\nWe can clearly see that after multiple random searches, the final selected cat image shows richer fur details and significantly improved PickScore. However, this brute-force random search is extremely inefficient - generation time multiplies while easily hitting quality limits. Therefore, we need a more efficient search method that achieves higher scores within the same computational budget.\n\n### 2.2 SES Search\n\nTo overcome random search limitations, we introduce the Spectral Evolution Search (SES) algorithm [[3]](https://arxiv.org/abs/2602.03208). Detailed code is available at [diffsynth/utils/ses](https://github.com/modelscope/DiffSynth-Studio/blob/main/diffsynth/utils/ses).\n\nImage generation in diffusion models is largely determined by low-frequency components in the initial noise. 
The SES algorithm decomposes Gaussian noise through wavelet transforms, fixes high-frequency details, and applies an evolutionary search using the cross-entropy method specifically on low-frequency components to find optimal initial noise with higher efficiency.\n\nRun the following code to use SES to efficiently search for the best initial Gaussian noise.\n\n```python\nfrom diffsynth.utils.ses import ses_search\n\nbest_noise = ses_search(\n    base_latents=noise,\n    objective_reward_fn=lambda noise: evaluate_noise(noise, pipe, reward_model, prompt),\n    total_eval_budget=50,\n)\nimage_3 = pipe(prompt, num_inference_steps=4, initial_noise=best_noise)\nprint(\"Score:\", reward_model(image_3, prompt))\nimage_3\n```\n\n![Image](https://github.com/user-attachments/assets/9a3f7598-3812-46d2-b333-cd65e49886ab)\n\nThe results show that, under the same computational budget, SES achieves a significantly higher PickScore than random search. The \"sketch cat\" demonstrates more refined overall composition and more layered contrast between light and shadow.\n\nInference-time scaling can achieve higher image quality at the cost of longer inference time. The generated image data can then be used to train the model itself through methods like DPO [[4]](https://arxiv.org/abs/2311.12908) or differential training [[5]](https://arxiv.org/abs/2412.12888), opening another interesting research direction.\n"
  },
  {
    "path": "docs/en/Research_Tutorial/train_from_scratch.md",
    "content": "# Training Models from Scratch\n\nDiffSynth-Studio's training engine supports training foundation models from scratch. This article introduces how to train a small text-to-image model with only 0.1B parameters from scratch.\n\n## 1. Building Model Architecture\n\n### 1.1 Diffusion Model\n\nFrom UNet [[1]](https://arxiv.org/abs/1505.04597) [[2]](https://arxiv.org/abs/2112.10752) to DiT [[3]](https://arxiv.org/abs/2212.09748) [[4]](https://arxiv.org/abs/2403.03206), the mainstream model architectures of Diffusion have undergone multiple evolutions. Typically, a Diffusion model's inputs include:\n\n* Image tensor (`latents`): The encoding of images, generated by the VAE model, containing partial noise\n* Text tensor (`prompt_embeds`): The encoding of text, generated by the text encoder\n* Timestep (`timestep`): A scalar used to mark which stage of the Diffusion process we are currently at\n\nThe model's output is a tensor with the same shape as the image tensor, representing the denoising direction predicted by the model. For details about Diffusion model theory, please refer to [Basic Principles of Diffusion Models](../Training/Understanding_Diffusion_models.md). 
In this article, we build a DiT model with only 0.1B parameters: `AAADiT`.\n\n<details>\n<summary>Model Architecture Code</summary>\n\n```python\nimport torch, accelerate\nfrom PIL import Image\nfrom typing import Union\nfrom tqdm import tqdm\nfrom einops import rearrange, repeat\n\nfrom transformers import AutoProcessor, AutoTokenizer\nfrom diffsynth.core import ModelConfig, gradient_checkpoint_forward, attention_forward, UnifiedDataset, load_model\nfrom diffsynth.diffusion import FlowMatchScheduler, DiffusionTrainingModule, FlowMatchSFTLoss, ModelLogger, launch_training_task\nfrom diffsynth.diffusion.base_pipeline import BasePipeline, PipelineUnit\nfrom diffsynth.models.general_modules import TimestepEmbeddings\nfrom diffsynth.models.z_image_text_encoder import ZImageTextEncoder\nfrom diffsynth.models.flux2_vae import Flux2VAE\n\n\nclass AAAPositionalEmbedding(torch.nn.Module):\n    def __init__(self, height=16, width=16, dim=1024):\n        super().__init__()\n        self.image_emb = torch.nn.Parameter(torch.randn((1, dim, height, width)))\n        self.text_emb = torch.nn.Parameter(torch.randn((dim,)))\n\n    def forward(self, image, text):\n        height, width = image.shape[-2:]\n        image_emb = self.image_emb.to(device=image.device, dtype=image.dtype)\n        image_emb = torch.nn.functional.interpolate(image_emb, size=(height, width), mode=\"bilinear\")\n        image_emb = rearrange(image_emb, \"B C H W -> B (H W) C\")\n        text_emb = self.text_emb.to(device=text.device, dtype=text.dtype)\n        text_emb = repeat(text_emb, \"C -> B L C\", B=text.shape[0], L=text.shape[1])\n        emb = torch.concat([image_emb, text_emb], dim=1)\n        return emb\n\n\nclass AAABlock(torch.nn.Module):\n    def __init__(self, dim=1024, num_heads=32):\n        super().__init__()\n        self.norm_attn = torch.nn.RMSNorm(dim, elementwise_affine=False)\n        self.to_q = torch.nn.Linear(dim, dim)\n        self.to_k = torch.nn.Linear(dim, dim)\n        self.to_v 
= torch.nn.Linear(dim, dim)\n        self.to_out = torch.nn.Linear(dim, dim)\n        self.norm_mlp = torch.nn.RMSNorm(dim, elementwise_affine=False)\n        self.ff = torch.nn.Sequential(\n            torch.nn.Linear(dim, dim*3),\n            torch.nn.SiLU(),\n            torch.nn.Linear(dim*3, dim),\n        )\n        self.to_gate = torch.nn.Linear(dim, dim * 2)\n        self.num_heads = num_heads\n\n    def attention(self, emb, pos_emb):\n        emb = self.norm_attn(emb + pos_emb)\n        q, k, v = self.to_q(emb), self.to_k(emb), self.to_v(emb)\n        emb = attention_forward(\n            q, k, v,\n            q_pattern=\"b s (n d)\", k_pattern=\"b s (n d)\", v_pattern=\"b s (n d)\", out_pattern=\"b s (n d)\",\n            dims={\"n\": self.num_heads},\n        )\n        emb = self.to_out(emb)\n        return emb\n    \n    def feed_forward(self, emb, pos_emb):\n        emb = self.norm_mlp(emb + pos_emb)\n        emb = self.ff(emb)\n        return emb\n    \n    def forward(self, emb, pos_emb, t_emb):\n        gate_attn, gate_mlp = self.to_gate(t_emb).chunk(2, dim=-1)\n        emb = emb + self.attention(emb, pos_emb) * (1 + gate_attn)\n        emb = emb + self.feed_forward(emb, pos_emb) * (1 + gate_mlp)\n        return emb\n\n\nclass AAADiT(torch.nn.Module):\n    def __init__(self, dim=1024):\n        super().__init__()\n        self.pos_embedder = AAAPositionalEmbedding(dim=dim)\n        self.timestep_embedder = TimestepEmbeddings(256, dim)\n        self.image_embedder = torch.nn.Sequential(torch.nn.Linear(128, dim), torch.nn.LayerNorm(dim))\n        self.text_embedder = torch.nn.Sequential(torch.nn.Linear(1024, dim), torch.nn.LayerNorm(dim))\n        self.blocks = torch.nn.ModuleList([AAABlock(dim) for _ in range(10)])\n        self.proj_out = torch.nn.Linear(dim, 128)\n\n    def forward(\n        self,\n        latents,\n        prompt_embeds,\n        timestep,\n        use_gradient_checkpointing=False,\n        
use_gradient_checkpointing_offload=False,\n    ):\n        pos_emb = self.pos_embedder(latents, prompt_embeds)\n        t_emb = self.timestep_embedder(timestep, dtype=latents.dtype).view(1, 1, -1)\n        image = self.image_embedder(rearrange(latents, \"B C H W -> B (H W) C\"))\n        text = self.text_embedder(prompt_embeds)\n        emb = torch.concat([image, text], dim=1)\n        for block_id, block in enumerate(self.blocks):\n            emb = gradient_checkpoint_forward(\n                block,\n                use_gradient_checkpointing=use_gradient_checkpointing,\n                use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,\n                emb=emb,\n                pos_emb=pos_emb,\n                t_emb=t_emb,\n            )\n        emb = emb[:, :latents.shape[-1] * latents.shape[-2]]\n        emb = self.proj_out(emb)\n        emb = rearrange(emb, \"B (H W) C -> B C H W\", W=latents.shape[-1])\n        return emb\n```\n\n</details>\n\n### 1.2 Encoder-Decoder Models\n\nBesides the Diffusion model used for denoising, we also need two other models:\n\n* Text Encoder: Used to encode text into tensors. We adopt the [Qwen/Qwen3-0.6B](https://modelscope.cn/models/Qwen/Qwen3-0.6B) model.\n* VAE Encoder-Decoder: The encoder part is used to encode images into tensors, and the decoder part is used to decode image tensors into images. We adopt the VAE model from [black-forest-labs/FLUX.2-klein-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B).\n\nThe architectures of these two models are already integrated in DiffSynth-Studio, located at [/diffsynth/models/z_image_text_encoder.py](https://github.com/modelscope/DiffSynth-Studio/blob/main/diffsynth/models/z_image_text_encoder.py) and [/diffsynth/models/flux2_vae.py](https://github.com/modelscope/DiffSynth-Studio/blob/main/diffsynth/models/flux2_vae.py), so we don't need to modify any code.\n\n## 2. 
Building Pipeline\n\nWe introduced how to build a model Pipeline in the document [Integrating Pipeline](../Developer_Guide/Building_a_Pipeline.md). For the model in this article, we also need to build a Pipeline to connect the text encoder, Diffusion model, and VAE encoder-decoder.\n\n<details>\n<summary>Pipeline Code</summary>\n\n```python\nclass AAAImagePipeline(BasePipeline):\n    def __init__(self, device=\"cuda\", torch_dtype=torch.bfloat16):\n        super().__init__(\n            device=device, torch_dtype=torch_dtype,\n            height_division_factor=16, width_division_factor=16,\n        )\n        self.scheduler = FlowMatchScheduler(\"FLUX.2\")\n        self.text_encoder: ZImageTextEncoder = None\n        self.dit: AAADiT = None\n        self.vae: Flux2VAE = None\n        self.tokenizer: AutoProcessor = None\n        self.in_iteration_models = (\"dit\",)\n        self.units = [\n            AAAUnit_PromptEmbedder(),\n            AAAUnit_NoiseInitializer(),\n            AAAUnit_InputImageEmbedder(),\n        ]\n        self.model_fn = model_fn_aaa\n    \n    @staticmethod\n    def from_pretrained(\n        torch_dtype: torch.dtype = torch.bfloat16,\n        device: Union[str, torch.device] = \"cuda\",\n        model_configs: list[ModelConfig] = [],\n        tokenizer_config: ModelConfig = None,\n        vram_limit: float = None,\n    ):\n        # Initialize pipeline\n        pipe = AAAImagePipeline(device=device, torch_dtype=torch_dtype)\n        model_pool = pipe.download_and_load_models(model_configs, vram_limit)\n        \n        # Fetch models\n        pipe.text_encoder = model_pool.fetch_model(\"z_image_text_encoder\")\n        pipe.dit = model_pool.fetch_model(\"aaa_dit\")\n        pipe.vae = model_pool.fetch_model(\"flux2_vae\")\n        if tokenizer_config is not None:\n            tokenizer_config.download_if_necessary()\n            pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)\n        \n        # VRAM Management\n   
     pipe.vram_management_enabled = pipe.check_vram_management_state()\n        return pipe\n    \n    @torch.no_grad()\n    def __call__(\n        self,\n        # Prompt\n        prompt: str,\n        negative_prompt: str = \"\",\n        cfg_scale: float = 1.0,\n        # Image\n        input_image: Image.Image = None,\n        denoising_strength: float = 1.0,\n        # Shape\n        height: int = 1024,\n        width: int = 1024,\n        # Randomness\n        seed: int = None,\n        rand_device: str = \"cpu\",\n        # Steps\n        num_inference_steps: int = 30,\n        # Progress bar\n        progress_bar_cmd = tqdm,\n    ):\n        self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=height//16*width//16)\n\n        # Parameters\n        inputs_posi = {\"prompt\": prompt}\n        inputs_nega = {\"negative_prompt\": negative_prompt}\n        inputs_shared = {\n            \"cfg_scale\": cfg_scale,\n            \"input_image\": input_image, \"denoising_strength\": denoising_strength,\n            \"height\": height, \"width\": width,\n            \"seed\": seed, \"rand_device\": rand_device,\n            \"num_inference_steps\": num_inference_steps,\n        }\n        for unit in self.units:\n            inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)\n\n        # Denoise\n        self.load_models_to_device(self.in_iteration_models)\n        models = {name: getattr(self, name) for name in self.in_iteration_models}\n        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):\n            timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)\n            noise_pred = self.cfg_guided_model_fn(\n                self.model_fn, cfg_scale,\n                inputs_shared, inputs_posi, inputs_nega,\n                **models, timestep=timestep, progress_id=progress_id\n            )\n 
           inputs_shared[\"latents\"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)\n        \n        # Decode\n        self.load_models_to_device(['vae'])\n        image = self.vae.decode(inputs_shared[\"latents\"])\n        image = self.vae_output_to_image(image)\n        self.load_models_to_device([])\n\n        return image\n\n\nclass AAAUnit_PromptEmbedder(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            seperate_cfg=True,\n            input_params_posi={\"prompt\": \"prompt\"},\n            input_params_nega={\"prompt\": \"negative_prompt\"},\n            output_params=(\"prompt_embeds\",),\n            onload_model_names=(\"text_encoder\",)\n        )\n        self.hidden_states_layers = (-1,)\n\n    def process(self, pipe: AAAImagePipeline, prompt):\n        pipe.load_models_to_device(self.onload_model_names)\n        text = pipe.tokenizer.apply_chat_template(\n            [{\"role\": \"user\", \"content\": prompt}],\n            tokenize=False,\n            add_generation_prompt=True,\n            enable_thinking=False,\n        )\n        inputs = pipe.tokenizer(text, return_tensors=\"pt\", padding=\"max_length\", truncation=True, max_length=128).to(pipe.device)\n        output = pipe.text_encoder(**inputs, output_hidden_states=True, use_cache=False)\n        prompt_embeds = torch.concat([output.hidden_states[k] for k in self.hidden_states_layers], dim=-1)\n        return {\"prompt_embeds\": prompt_embeds}\n\n\nclass AAAUnit_NoiseInitializer(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            input_params=(\"height\", \"width\", \"seed\", \"rand_device\"),\n            output_params=(\"noise\",),\n        )\n\n    def process(self, pipe: AAAImagePipeline, height, width, seed, rand_device):\n        noise = pipe.generate_noise((1, 128, height//16, width//16), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)\n        return 
{\"noise\": noise}\n\n\nclass AAAUnit_InputImageEmbedder(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            input_params=(\"input_image\", \"noise\"),\n            output_params=(\"latents\", \"input_latents\"),\n            onload_model_names=(\"vae\",)\n        )\n\n    def process(self, pipe: AAAImagePipeline, input_image, noise):\n        if input_image is None:\n            return {\"latents\": noise, \"input_latents\": None}\n        pipe.load_models_to_device(['vae'])\n        image = pipe.preprocess_image(input_image)\n        input_latents = pipe.vae.encode(image)\n        if pipe.scheduler.training:\n            return {\"latents\": noise, \"input_latents\": input_latents}\n        else:\n            latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])\n            return {\"latents\": latents, \"input_latents\": input_latents}\n\n\ndef model_fn_aaa(\n    dit: AAADiT,\n    latents=None,\n    prompt_embeds=None,\n    timestep=None,\n    use_gradient_checkpointing=False,\n    use_gradient_checkpointing_offload=False,\n    **kwargs,\n):\n    model_output = dit(\n        latents,\n        prompt_embeds,\n        timestep,\n        use_gradient_checkpointing=use_gradient_checkpointing,\n        use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,\n    )\n    return model_output\n```\n\n</details>\n\n## 3. Preparing Dataset\n\nTo quickly verify training effectiveness, we use the dataset [Pokemon-First Generation](https://modelscope.cn/datasets/DiffSynth-Studio/pokemon-gen1), which is reproduced from the open-source project [pokemon-dataset-zh](https://github.com/42arch/pokemon-dataset-zh), containing 151 first-generation Pokemon from Bulbasaur to Mew. 
If you want to use other datasets, please refer to the documents [Preparing Datasets](../Pipeline_Usage/Model_Training.md#preparing-datasets) and [`diffsynth.core.data`](../API_Reference/core/data.md).\n\n```shell\nmodelscope download --dataset DiffSynth-Studio/pokemon-gen1 --local_dir ./data\n```\n\n## 4. Start Training\n\nWith the Pipeline in place, the training process is straightforward to implement. We have placed the complete code at [../Research_Tutorial/train_from_scratch.py](https://github.com/modelscope/DiffSynth-Studio/blob/main/docs/en/Research_Tutorial/train_from_scratch.py), which can be run directly with `python docs/en/Research_Tutorial/train_from_scratch.py` for single-GPU training.\n\nFor multi-GPU parallel training, first run `accelerate config` to set the relevant parameters, then start training with `accelerate launch docs/en/Research_Tutorial/train_from_scratch.py`.\n\nThis training script has no stopping condition; stop it manually when needed. The model converges after approximately 60,000 training steps, which takes 10-20 hours on a single GPU.\n\n<details>\n<summary>Training Code</summary>\n\n```python\nclass AAATrainingModule(DiffusionTrainingModule):\n    def __init__(self, device):\n        super().__init__()\n        self.pipe = AAAImagePipeline.from_pretrained(\n            torch_dtype=torch.bfloat16,\n            device=device,\n            model_configs=[\n                ModelConfig(model_id=\"Qwen/Qwen3-0.6B\", origin_file_pattern=\"model.safetensors\"),\n                ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-4B\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n            ],\n            tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen3-0.6B\", origin_file_pattern=\"./\"),\n        )\n        self.pipe.dit = AAADiT().to(dtype=torch.bfloat16, device=device)\n        self.pipe.freeze_except([\"dit\"])\n        self.pipe.scheduler.set_timesteps(1000, training=True)\n\n    def 
forward(self, data):\n        inputs_posi = {\"prompt\": data[\"prompt\"]}\n        inputs_nega = {\"negative_prompt\": \"\"}\n        inputs_shared = {\n            \"input_image\": data[\"image\"],\n            \"height\": data[\"image\"].size[1],\n            \"width\": data[\"image\"].size[0],\n            \"cfg_scale\": 1,\n            \"use_gradient_checkpointing\": False,\n            \"use_gradient_checkpointing_offload\": False,\n        }\n        for unit in self.pipe.units:\n            inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega)\n        loss = FlowMatchSFTLoss(self.pipe, **inputs_shared, **inputs_posi)\n        return loss\n\n\nif __name__ == \"__main__\":\n    accelerator = accelerate.Accelerator(gradient_accumulation_steps=1)\n    dataset = UnifiedDataset(\n        base_path=\"data/images\",\n        metadata_path=\"data/metadata_merged.csv\",\n        max_data_items=10000000,\n        data_file_keys=(\"image\",),\n        main_data_operator=UnifiedDataset.default_image_operator(base_path=\"data/images\", height=256, width=256)\n    )\n    model = AAATrainingModule(device=accelerator.device)\n    model_logger = ModelLogger(\n        \"models/AAA/v1\",\n        remove_prefix_in_ckpt=\"pipe.dit.\",\n    )\n    launch_training_task(\n        accelerator, dataset, model, model_logger,\n        learning_rate=2e-4,\n        num_workers=4,\n        save_steps=50000,\n        num_epochs=999999,\n    )\n```\n\n</details>\n\n## 5. 
Verifying Training Results\n\nIf you don't want to wait for the model training to complete, you can directly download [our pre-trained model](https://modelscope.cn/models/DiffSynth-Studio/AAAMyModel).\n\n```shell\nmodelscope download --model DiffSynth-Studio/AAAMyModel step-600000.safetensors --local_dir models/DiffSynth-Studio/AAAMyModel\n```\n\nLoading the model\n\n```python\nfrom diffsynth import load_model\n\npipe = AAAImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen3-0.6B\", origin_file_pattern=\"model.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-4B\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen3-0.6B\", origin_file_pattern=\"./\"),\n)\npipe.dit = load_model(AAADiT, \"models/DiffSynth-Studio/AAAMyModel/step-600000.safetensors\", torch_dtype=torch.bfloat16, device=\"cuda\")\n```\n\nModel inference, generating the first-generation Pokemon \"starter trio\". 
At this point, the images generated by the model basically match the training data.\n\n```python\nfor seed, prompt in enumerate([\n    \"green, lizard, plant, Grass, Poison, seed on back, red eyes, smiling expression, short stout limbs, sharp claws\",\n    \"orange, cream, lizard, Fire, flame on tail tip, large eyes, smiling expression, cream-colored belly patch, sharp claws\",\n    \"blue, beige, brown, turtle, water type, shell, big eyes, short limbs, curled tail\",\n]):\n    image = pipe(\n        prompt=prompt,\n        negative_prompt=\" \",\n        num_inference_steps=30,\n        cfg_scale=10,\n        seed=seed,\n        height=256, width=256,\n    )\n    image.save(f\"image_{seed}.jpg\")\n```\n\n|![Image](https://github.com/user-attachments/assets/3c620fbf-5d28-4a1a-b887-519d85ac7d1c)|![Image](https://github.com/user-attachments/assets/909efd4c-9e61-4b33-9321-39da0e499b00)|![Image](https://github.com/user-attachments/assets/f3474bcd-b474-4a90-a1ea-579f67e161e3)|\n|-|-|-|\n\nModel inference, generating Pokemon with \"sharp claws\". At this point, different random seeds can produce different image results.\n\n```python\nfor seed, prompt in enumerate([\n    \"sharp claws\",\n    \"sharp claws\",\n    \"sharp claws\",\n]):\n    image = pipe(\n        prompt=prompt,\n        negative_prompt=\" \",\n        num_inference_steps=30,\n        cfg_scale=10,\n        seed=seed+4,\n        height=256, width=256,\n    )\n    image.save(f\"image_sharp_claws_{seed}.jpg\")\n```\n\n|![Image](https://github.com/user-attachments/assets/94862edd-96ae-4276-a38f-795249f11a13)|![Image](https://github.com/user-attachments/assets/b2291f23-20ba-42de-8bfd-76cb4afc6eea)|![Image](https://github.com/user-attachments/assets/f2aab9a4-85ec-498e-8039-648b1289796e)|\n|-|-|-|\n\nNow, we have obtained a 0.1B small text-to-image model. This model can already generate 151 Pokemon, but cannot generate other image content. 
By scaling up the training data, the number of model parameters, and the number of GPUs from this starting point, you can train a more powerful text-to-image model!"
  },
  {
    "path": "docs/en/Research_Tutorial/train_from_scratch.py",
    "content": "import torch, accelerate\nfrom PIL import Image\nfrom typing import Union\nfrom tqdm import tqdm\nfrom einops import rearrange, repeat\n\nfrom transformers import AutoProcessor, AutoTokenizer\nfrom diffsynth.core import ModelConfig, gradient_checkpoint_forward, attention_forward, UnifiedDataset, load_model\nfrom diffsynth.diffusion import FlowMatchScheduler, DiffusionTrainingModule, FlowMatchSFTLoss, ModelLogger, launch_training_task\nfrom diffsynth.diffusion.base_pipeline import BasePipeline, PipelineUnit\nfrom diffsynth.models.general_modules import TimestepEmbeddings\nfrom diffsynth.models.z_image_text_encoder import ZImageTextEncoder\nfrom diffsynth.models.flux2_vae import Flux2VAE\n\n\nclass AAAPositionalEmbedding(torch.nn.Module):\n    def __init__(self, height=16, width=16, dim=1024):\n        super().__init__()\n        self.image_emb = torch.nn.Parameter(torch.randn((1, dim, height, width)))\n        self.text_emb = torch.nn.Parameter(torch.randn((dim,)))\n\n    def forward(self, image, text):\n        height, width = image.shape[-2:]\n        image_emb = self.image_emb.to(device=image.device, dtype=image.dtype)\n        image_emb = torch.nn.functional.interpolate(image_emb, size=(height, width), mode=\"bilinear\")\n        image_emb = rearrange(image_emb, \"B C H W -> B (H W) C\")\n        text_emb = self.text_emb.to(device=text.device, dtype=text.dtype)\n        text_emb = repeat(text_emb, \"C -> B L C\", B=text.shape[0], L=text.shape[1])\n        emb = torch.concat([image_emb, text_emb], dim=1)\n        return emb\n\n\nclass AAABlock(torch.nn.Module):\n    def __init__(self, dim=1024, num_heads=32):\n        super().__init__()\n        self.norm_attn = torch.nn.RMSNorm(dim, elementwise_affine=False)\n        self.to_q = torch.nn.Linear(dim, dim)\n        self.to_k = torch.nn.Linear(dim, dim)\n        self.to_v = torch.nn.Linear(dim, dim)\n        self.to_out = torch.nn.Linear(dim, dim)\n        self.norm_mlp = torch.nn.RMSNorm(dim, 
elementwise_affine=False)\n        self.ff = torch.nn.Sequential(\n            torch.nn.Linear(dim, dim*3),\n            torch.nn.SiLU(),\n            torch.nn.Linear(dim*3, dim),\n        )\n        self.to_gate = torch.nn.Linear(dim, dim * 2)\n        self.num_heads = num_heads\n\n    def attention(self, emb, pos_emb):\n        emb = self.norm_attn(emb + pos_emb)\n        q, k, v = self.to_q(emb), self.to_k(emb), self.to_v(emb)\n        emb = attention_forward(\n            q, k, v,\n            q_pattern=\"b s (n d)\", k_pattern=\"b s (n d)\", v_pattern=\"b s (n d)\", out_pattern=\"b s (n d)\",\n            dims={\"n\": self.num_heads},\n        )\n        emb = self.to_out(emb)\n        return emb\n    \n    def feed_forward(self, emb, pos_emb):\n        emb = self.norm_mlp(emb + pos_emb)\n        emb = self.ff(emb)\n        return emb\n    \n    def forward(self, emb, pos_emb, t_emb):\n        gate_attn, gate_mlp = self.to_gate(t_emb).chunk(2, dim=-1)\n        emb = emb + self.attention(emb, pos_emb) * (1 + gate_attn)\n        emb = emb + self.feed_forward(emb, pos_emb) * (1 + gate_mlp)\n        return emb\n\n\nclass AAADiT(torch.nn.Module):\n    def __init__(self, dim=1024):\n        super().__init__()\n        self.pos_embedder = AAAPositionalEmbedding(dim=dim)\n        self.timestep_embedder = TimestepEmbeddings(256, dim)\n        self.image_embedder = torch.nn.Sequential(torch.nn.Linear(128, dim), torch.nn.LayerNorm(dim))\n        self.text_embedder = torch.nn.Sequential(torch.nn.Linear(1024, dim), torch.nn.LayerNorm(dim))\n        self.blocks = torch.nn.ModuleList([AAABlock(dim) for _ in range(10)])\n        self.proj_out = torch.nn.Linear(dim, 128)\n\n    def forward(\n        self,\n        latents,\n        prompt_embeds,\n        timestep,\n        use_gradient_checkpointing=False,\n        use_gradient_checkpointing_offload=False,\n    ):\n        pos_emb = self.pos_embedder(latents, prompt_embeds)\n        t_emb = self.timestep_embedder(timestep, 
dtype=latents.dtype).view(1, 1, -1)\n        image = self.image_embedder(rearrange(latents, \"B C H W -> B (H W) C\"))\n        text = self.text_embedder(prompt_embeds)\n        emb = torch.concat([image, text], dim=1)\n        for block_id, block in enumerate(self.blocks):\n            emb = gradient_checkpoint_forward(\n                block,\n                use_gradient_checkpointing=use_gradient_checkpointing,\n                use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,\n                emb=emb,\n                pos_emb=pos_emb,\n                t_emb=t_emb,\n            )\n        emb = emb[:, :latents.shape[-1] * latents.shape[-2]]\n        emb = self.proj_out(emb)\n        emb = rearrange(emb, \"B (H W) C -> B C H W\", W=latents.shape[-1])\n        return emb\n\n\nclass AAAImagePipeline(BasePipeline):\n    def __init__(self, device=\"cuda\", torch_dtype=torch.bfloat16):\n        super().__init__(\n            device=device, torch_dtype=torch_dtype,\n            height_division_factor=16, width_division_factor=16,\n        )\n        self.scheduler = FlowMatchScheduler(\"FLUX.2\")\n        self.text_encoder: ZImageTextEncoder = None\n        self.dit: AAADiT = None\n        self.vae: Flux2VAE = None\n        self.tokenizer: AutoProcessor = None\n        self.in_iteration_models = (\"dit\",)\n        self.units = [\n            AAAUnit_PromptEmbedder(),\n            AAAUnit_NoiseInitializer(),\n            AAAUnit_InputImageEmbedder(),\n        ]\n        self.model_fn = model_fn_aaa\n    \n    @staticmethod\n    def from_pretrained(\n        torch_dtype: torch.dtype = torch.bfloat16,\n        device: Union[str, torch.device] = \"cuda\",\n        model_configs: list[ModelConfig] = [],\n        tokenizer_config: ModelConfig = None,\n        vram_limit: float = None,\n    ):\n        # Initialize pipeline\n        pipe = AAAImagePipeline(device=device, torch_dtype=torch_dtype)\n        model_pool = 
pipe.download_and_load_models(model_configs, vram_limit)\n        \n        # Fetch models\n        pipe.text_encoder = model_pool.fetch_model(\"z_image_text_encoder\")\n        pipe.dit = model_pool.fetch_model(\"aaa_dit\")\n        pipe.vae = model_pool.fetch_model(\"flux2_vae\")\n        if tokenizer_config is not None:\n            tokenizer_config.download_if_necessary()\n            pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)\n        \n        # VRAM Management\n        pipe.vram_management_enabled = pipe.check_vram_management_state()\n        return pipe\n    \n    @torch.no_grad()\n    def __call__(\n        self,\n        # Prompt\n        prompt: str,\n        negative_prompt: str = \"\",\n        cfg_scale: float = 1.0,\n        # Image\n        input_image: Image.Image = None,\n        denoising_strength: float = 1.0,\n        # Shape\n        height: int = 1024,\n        width: int = 1024,\n        # Randomness\n        seed: int = None,\n        rand_device: str = \"cpu\",\n        # Steps\n        num_inference_steps: int = 30,\n        # Progress bar\n        progress_bar_cmd = tqdm,\n    ):\n        self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=height//16*width//16)\n\n        # Parameters\n        inputs_posi = {\"prompt\": prompt}\n        inputs_nega = {\"negative_prompt\": negative_prompt}\n        inputs_shared = {\n            \"cfg_scale\": cfg_scale,\n            \"input_image\": input_image, \"denoising_strength\": denoising_strength,\n            \"height\": height, \"width\": width,\n            \"seed\": seed, \"rand_device\": rand_device,\n            \"num_inference_steps\": num_inference_steps,\n        }\n        for unit in self.units:\n            inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)\n\n        # Denoise\n        self.load_models_to_device(self.in_iteration_models)\n     
   models = {name: getattr(self, name) for name in self.in_iteration_models}\n        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):\n            timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)\n            noise_pred = self.cfg_guided_model_fn(\n                self.model_fn, cfg_scale,\n                inputs_shared, inputs_posi, inputs_nega,\n                **models, timestep=timestep, progress_id=progress_id\n            )\n            inputs_shared[\"latents\"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)\n        \n        # Decode\n        self.load_models_to_device(['vae'])\n        image = self.vae.decode(inputs_shared[\"latents\"])\n        image = self.vae_output_to_image(image)\n        self.load_models_to_device([])\n\n        return image\n\n\nclass AAAUnit_PromptEmbedder(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            seperate_cfg=True,\n            input_params_posi={\"prompt\": \"prompt\"},\n            input_params_nega={\"prompt\": \"negative_prompt\"},\n            output_params=(\"prompt_embeds\",),\n            onload_model_names=(\"text_encoder\",)\n        )\n        self.hidden_states_layers = (-1,)\n\n    def process(self, pipe: AAAImagePipeline, prompt):\n        pipe.load_models_to_device(self.onload_model_names)\n        text = pipe.tokenizer.apply_chat_template(\n            [{\"role\": \"user\", \"content\": prompt}],\n            tokenize=False,\n            add_generation_prompt=True,\n            enable_thinking=False,\n        )\n        inputs = pipe.tokenizer(text, return_tensors=\"pt\", padding=\"max_length\", truncation=True, max_length=128).to(pipe.device)\n        output = pipe.text_encoder(**inputs, output_hidden_states=True, use_cache=False)\n        prompt_embeds = torch.concat([output.hidden_states[k] for k in self.hidden_states_layers], dim=-1)\n        return 
{\"prompt_embeds\": prompt_embeds}\n\n\nclass AAAUnit_NoiseInitializer(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            input_params=(\"height\", \"width\", \"seed\", \"rand_device\"),\n            output_params=(\"noise\",),\n        )\n\n    def process(self, pipe: AAAImagePipeline, height, width, seed, rand_device):\n        noise = pipe.generate_noise((1, 128, height//16, width//16), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)\n        return {\"noise\": noise}\n\n\nclass AAAUnit_InputImageEmbedder(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            input_params=(\"input_image\", \"noise\"),\n            output_params=(\"latents\", \"input_latents\"),\n            onload_model_names=(\"vae\",)\n        )\n\n    def process(self, pipe: AAAImagePipeline, input_image, noise):\n        if input_image is None:\n            return {\"latents\": noise, \"input_latents\": None}\n        pipe.load_models_to_device(['vae'])\n        image = pipe.preprocess_image(input_image)\n        input_latents = pipe.vae.encode(image)\n        if pipe.scheduler.training:\n            return {\"latents\": noise, \"input_latents\": input_latents}\n        else:\n            latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])\n            return {\"latents\": latents, \"input_latents\": input_latents}\n\n\ndef model_fn_aaa(\n    dit: AAADiT,\n    latents=None,\n    prompt_embeds=None,\n    timestep=None,\n    use_gradient_checkpointing=False,\n    use_gradient_checkpointing_offload=False,\n    **kwargs,\n):\n    model_output = dit(\n        latents,\n        prompt_embeds,\n        timestep,\n        use_gradient_checkpointing=use_gradient_checkpointing,\n        use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,\n    )\n    return model_output\n\n\nclass AAATrainingModule(DiffusionTrainingModule):\n    def __init__(self, device):\n        
super().__init__()\n        self.pipe = AAAImagePipeline.from_pretrained(\n            torch_dtype=torch.bfloat16,\n            device=device,\n            model_configs=[\n                ModelConfig(model_id=\"Qwen/Qwen3-0.6B\", origin_file_pattern=\"model.safetensors\"),\n                ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-4B\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n            ],\n            tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen3-0.6B\", origin_file_pattern=\"./\"),\n        )\n        self.pipe.dit = AAADiT().to(dtype=torch.bfloat16, device=device)\n        self.pipe.freeze_except([\"dit\"])\n        self.pipe.scheduler.set_timesteps(1000, training=True)\n\n    def forward(self, data):\n        inputs_posi = {\"prompt\": data[\"prompt\"]}\n        inputs_nega = {\"negative_prompt\": \"\"}\n        inputs_shared = {\n            \"input_image\": data[\"image\"],\n            \"height\": data[\"image\"].size[1],\n            \"width\": data[\"image\"].size[0],\n            \"cfg_scale\": 1,\n            \"use_gradient_checkpointing\": False,\n            \"use_gradient_checkpointing_offload\": False,\n        }\n        for unit in self.pipe.units:\n            inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega)\n        loss = FlowMatchSFTLoss(self.pipe, **inputs_shared, **inputs_posi)\n        return loss\n\n\nif __name__ == \"__main__\":\n    accelerator = accelerate.Accelerator(gradient_accumulation_steps=1)\n    dataset = UnifiedDataset(\n        base_path=\"data/images\",\n        metadata_path=\"data/metadata_merged.csv\",\n        max_data_items=10000000,\n        data_file_keys=(\"image\",),\n        main_data_operator=UnifiedDataset.default_image_operator(base_path=\"data/images\", height=256, width=256)\n    )\n    model = AAATrainingModule(device=accelerator.device)\n    model_logger = ModelLogger(\n        
\"models/AAA/v1\",\n        remove_prefix_in_ckpt=\"pipe.dit.\",\n    )\n    launch_training_task(\n        accelerator, dataset, model, model_logger,\n        learning_rate=2e-4,\n        num_workers=4,\n        save_steps=50000,\n        num_epochs=999999,\n    )"
  },
  {
    "path": "docs/en/Training/Differential_LoRA.md",
    "content": "# Differential LoRA Training\n\nDifferential LoRA training is a special form of LoRA training designed to enable models to learn differences between images.\n\n## Training Approach\n\nWe were unable to identify the original proposer of differential LoRA training, as this technique has been circulating in the open-source community for a long time.\n\nAssume we have two similar-content images: Image 1 and Image 2. For example, both images contain a car, but Image 1 has fewer details while Image 2 has more details. In differential LoRA training, we perform two-step training:\n\n* Train LoRA 1 using Image 1 as training data with [standard supervised training](../Training/Supervised_Fine_Tuning.md)\n* Train LoRA 2 using Image 2 as training data, after integrating LoRA 1 into the base model, with [standard supervised training](../Training/Supervised_Fine_Tuning.md)\n\nIn the first training step, since there is only one training image, the LoRA model easily overfits. Therefore, after training, LoRA 1 will cause the model to generate Image 1 without hesitation, regardless of the random seed. In the second training step, the LoRA model overfits again. Thus, after training, with the combined effect of LoRA 1 and LoRA 2, the model will generate Image 2 without hesitation. In short:\n\n* LoRA 1 = Generate Image 1\n* LoRA 1 + LoRA 2 = Generate Image 2\n\nAt this point, discarding LoRA 1 and using only LoRA 2, the model will understand the difference between Image 1 and Image 2, making the generated content tend toward \"less like Image 1, more like Image 2.\"\n\nSingle training data can ensure the model overfits to the training data, but lacks stability. To improve stability, we can train with multiple image pairs and average the trained LoRA 2 models to obtain a more stable LoRA.\n\nUsing this training approach, some functionally unique LoRA models can be trained. 
For example, pairs of ugly and beautiful images can be used to train LoRAs that enhance image aesthetics, and pairs of low-detail and high-detail images can be used to train LoRAs that increase image detail.\n\n## Model Effects\n\nWe have trained several aesthetic enhancement LoRAs using the differential LoRA training technique. You can visit the corresponding model pages to view their generation results.\n\n* [DiffSynth-Studio/Qwen-Image-LoRA-ArtAug-v1](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-LoRA-ArtAug-v1)\n* [DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1](https://modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1)\n\n## Using Differential LoRA Training in the Training Framework\n\nThe first training step is identical to ordinary LoRA training. In the second step's training command, pass the path of the LoRA model file produced by the first step via the `--preset_lora_path` parameter, and set `--preset_lora_model` to the same value as `--lora_base_model`, so that LoRA 1 is loaded into the base model.\n\n## Framework Design Concept\n\nIn the training framework, the model pointed to by `--preset_lora_path` is loaded in the `switch_pipe_to_training_mode` method of `DiffusionTrainingModule`."
  },
  {
    "path": "docs/en/Training/Direct_Distill.md",
    "content": "# End-to-End Distillation Accelerated Training\n\n## Distillation Accelerated Training\n\nThe inference process of Diffusion models typically requires multi-step iterations, which improves generation quality but also makes the generation process slow. Through distillation accelerated training, the number of steps required to generate clear content can be reduced. The essence of distillation accelerated training technology is to align the generation effects of a small number of steps with those of a large number of steps.\n\nThere are diverse methods for distillation accelerated training, such as:\n\n* Adversarial training ADD (Adversarial Diffusion Distillation)\n    * Paper: https://arxiv.org/abs/2311.17042\n    * Model: [stabilityai/sdxl-turbo](https://modelscope.cn/models/stabilityai/sdxl-turbo)\n* Progressive training Hyper-SD\n    * Paper: https://arxiv.org/abs/2404.13686\n    * Model: [ByteDance/Hyper-SD](https://www.modelscope.cn/models/ByteDance/Hyper-SD)\n\n## Direct Distillation\n\nAt the framework level, supporting these distillation accelerated training schemes is extremely difficult. In the design of the training framework, we need to ensure that the training scheme meets the following conditions:\n\n* Generality: The training scheme applies to most Diffusion models supported within the framework, rather than only working for a specific model, which is a basic requirement for code framework construction.\n* Stability: The training scheme must ensure stable training effects without requiring manual fine-tuning of parameters. Adversarial training in ADD cannot guarantee stability.\n* Simplicity: The training scheme does not introduce additional complex modules. According to Occam's Razor principle, complex solutions may introduce potential risks. 
The Human Feedback Learning in Hyper-SD makes the training process overly complex.\n\nTherefore, in the training framework of `DiffSynth-Studio`, we designed an end-to-end distillation accelerated training scheme, which we call Direct Distillation. The pseudocode for the training process is as follows:\n\n```\nseed = xxx\nwith torch.no_grad():\n    image_1 = pipe(prompt, steps=50, seed=seed, cfg=4)\nimage_2 = pipe(prompt, steps=4, seed=seed, cfg=1)\nloss = torch.nn.functional.mse_loss(image_1, image_2)\n```\n\nYes, it's a very end-to-end training scheme that produces immediate results with minimal training.\n\n## Models Trained with Direct Distillation\n\nWe trained two models based on Qwen-Image using this scheme:\n\n* [DiffSynth-Studio/Qwen-Image-Distill-Full](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full): Full distillation training\n* [DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA): LoRA distillation training\n\nClick on the model links to go to the model pages and view the model effects.\n\n## Using Distillation Accelerated Training in the Training Framework\n\nFirst, you need to generate training data. 
Please refer to the [Model Inference](../Pipeline_Usage/Model_Inference.md) section to write inference code and generate training data with a sufficient number of inference steps.\n\nTaking Qwen-Image as an example, the following code can generate an image:\n\n```python\nfrom diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nimport torch\n\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n)\nprompt = \"精致肖像，水下少女，蓝裙飘逸，发丝轻扬，光影透澈，气泡环绕，面容恬静，细节精致，梦幻唯美。\"\nimage = pipe(prompt, seed=0, num_inference_steps=40)\nimage.save(\"image.jpg\")\n```\n\nThen, we compile the necessary information into [metadata files](../API_Reference/core/data.md#metadata):\n\n```csv\nimage,prompt,seed,rand_device,num_inference_steps,cfg_scale\ndistill_qwen/image.jpg,\"精致肖像，水下少女，蓝裙飘逸，发丝轻扬，光影透澈，气泡环绕，面容恬静，细节精致，梦幻唯美。\",0,cpu,4,1\n```\n\nThis sample dataset can be downloaded directly:\n\n```shell\nmodelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --local_dir ./data/diffsynth_example_dataset\n```\n\nThen start LoRA distillation accelerated training:\n\n```shell\nbash examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh\n```\n\nPlease note that in the [training script parameters](../Pipeline_Usage/Model_Training.md#script-parameters), the image resolution setting for the dataset should avoid triggering scaling processing. 
When setting `--height` and `--width` to enable fixed resolution, all training data must be generated with exactly the same width and height. When setting `--max_pixels` to enable dynamic resolution, the value of `--max_pixels` must be greater than or equal to the pixel area of any training image.\n\n## Framework Design Concept\n\nCompared to [Standard Supervised Training](../Training/Supervised_Fine_Tuning.md), Direct Distillation only differs in the training loss function. The loss function for Direct Distillation is `DirectDistillLoss` in `diffsynth.diffusion.loss`.\n\n## Future Work\n\nDirect Distillation is a highly general acceleration scheme, but it may not be the best-performing scheme. Therefore, we have not yet published this technology in paper form. We hope to leave this problem to the academic and open-source communities to solve together, and we look forward to developers providing more complete general training schemes."
  },
  {
    "path": "docs/en/Training/FP8_Precision.md",
    "content": "# Enabling FP8 Precision in Training\n\nAlthough `DiffSynth-Studio` supports [VRAM management](../Pipeline_Usage/VRAM_management.md) in model inference, most of the techniques for reducing VRAM usage are not suitable for training. Offloading would cause extremely slow training processes.\n\nFP8 precision is the only VRAM management strategy that can be enabled during training. However, this framework currently does not support native FP8 precision training. For reasons, see [Q&A: Why doesn't the training framework support native FP8 precision training?](../QA.md#why-doesnt-the-training-framework-support-native-fp8-precision-training). It only supports storing models whose parameters are not updated by gradients (models that do not require gradient backpropagation, or whose gradients only update their LoRA) in FP8 precision.\n\n## Enabling FP8\n\nIn our provided training scripts, you can quickly set models to be stored in FP8 precision through the `--fp8_models` parameter. Taking Qwen-Image LoRA training as an example, we provide a script for enabling FP8 training located at [`/examples/qwen_image/model_training/special/fp8_training/Qwen-Image-LoRA.sh`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/special/fp8_training/Qwen-Image-LoRA.sh). After training is completed, you can verify the training results with the script [`/examples/qwen_image/model_training/special/fp8_training/validate.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/special/fp8_training/validate.py).\n\nPlease note that this FP8 VRAM management strategy does not support gradient updates. When a model is set to be trainable, FP8 precision cannot be enabled for that model. 
Models that support FP8 include two types:\n\n* Parameters are not trainable, such as VAE models\n* Gradients do not update their parameters, such as DiT models in LoRA training\n\nExperimental verification shows that LoRA training with FP8 enabled does not cause significant image quality degradation. However, theoretical errors do exist. If you encounter training results inferior to BF16 precision training when using this feature, please provide feedback through GitHub issues.\n\n## Training Framework Design Concept\n\nThe training framework completely reuses the inference VRAM management, and only parses VRAM management configurations through `parse_model_configs` in `DiffusionTrainingModule` during training."
  },
  {
    "path": "docs/en/Training/Split_Training.md",
    "content": "# Two-Stage Split Training\n\nThis document introduces split training, which can automatically divide the training process into two stages, reducing VRAM usage while accelerating training speed.\n\n(Split training is an experimental feature that has not yet undergone large-scale validation. If you encounter any issues while using it, please submit an issue on GitHub.)\n\n## Split Training\n\nIn the training process of most models, a large amount of computation occurs in \"preprocessing,\" i.e., \"computations unrelated to the denoising model,\" including VAE encoding, text encoding, etc. When the corresponding model parameters are fixed, the results of these computations are repetitive. For each data sample, the computational results are identical across multiple epochs. Therefore, we provide a \"split training\" feature that can automatically analyze and split the training process.\n\nFor standard supervised training of ordinary text-to-image models, the splitting process is straightforward. It only requires splitting the computation of all [`Pipeline Units`](../Developer_Guide/Building_a_Pipeline.md#units) into the first stage, storing the computational results to disk, and then reading these results from disk in the second stage for subsequent computations. However, if gradient backpropagation is required during preprocessing, the situation becomes extremely complex. To address this, we introduced a computational graph splitting algorithm to analyze how to split the computation.\n\n## Computational Graph Splitting Algorithm\n\n> (We will supplement the detailed specifics of the computational graph splitting algorithm in future document updates)\n\n## Using Split Training\n\nSplit training already supports [Standard Supervised Training](../Training/Supervised_Fine_Tuning.md) and [Direct Distillation Training](../Training/Direct_Distill.md). The `--task` parameter in the training command controls this. 
Taking LoRA training of the Qwen-Image model as an example, the pre-split training command is:\n\n```shell\naccelerate launch examples/qwen_image/model_training/train.py \\\n  --dataset_base_path data/example_image_dataset \\\n  --dataset_metadata_path data/example_image_dataset/metadata.csv \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 50 \\\n  --model_id_with_origin_paths \"Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Qwen-Image_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1\" \\\n  --lora_rank 32 \\\n  --use_gradient_checkpointing \\\n  --dataset_num_workers 8 \\\n  --find_unused_parameters\n```\n\nAfter splitting, in the first stage, make the following modifications:\n\n* Change `--dataset_repeat` to 1 to avoid redundant computation\n* Change `--output_path` to the path where the first-stage computation results are saved\n* Add the additional parameter `--task \"sft:data_process\"`\n* Remove the DiT model from `--model_id_with_origin_paths`\n\n```shell\naccelerate launch examples/qwen_image/model_training/train.py \\\n  --dataset_base_path data/example_image_dataset \\\n  --dataset_metadata_path data/example_image_dataset/metadata.csv \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 1 \\\n  --model_id_with_origin_paths \"Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Qwen-Image-LoRA-splited-cache\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules 
\"to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1\" \\\n  --lora_rank 32 \\\n  --use_gradient_checkpointing \\\n  --dataset_num_workers 8 \\\n  --find_unused_parameters \\\n  --task \"sft:data_process\"\n```\n\nIn the second stage, make the following modifications:\n\n* Change `--dataset_base_path` to the `--output_path` of the first stage\n* Remove `--dataset_metadata_path`\n* Add the additional parameter `--task \"sft:train\"`\n* Remove the Text Encoder and VAE models from `--model_id_with_origin_paths`\n\n```shell\naccelerate launch examples/qwen_image/model_training/train.py \\\n  --dataset_base_path \"./models/train/Qwen-Image-LoRA-splited-cache\" \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 50 \\\n  --model_id_with_origin_paths \"Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Qwen-Image-LoRA-splited\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1\" \\\n  --lora_rank 32 \\\n  --use_gradient_checkpointing \\\n  --dataset_num_workers 8 \\\n  --find_unused_parameters \\\n  --task \"sft:train\"\n```\n\nWe provide sample training scripts and validation scripts located at `examples/qwen_image/model_training/special/split_training`.\n\n## Training Framework Design Concept\n\nThe training framework splits the computational units in the `Pipeline` through the `split_pipeline_units` method of `DiffusionTrainingModule`."
  },
  {
    "path": "docs/en/Training/Supervised_Fine_Tuning.md",
    "content": "# Standard Supervised Training\n\nAfter understanding the [Basic Principles of Diffusion Models](../Training/Understanding_Diffusion_models.md), this document introduces how the framework implements Diffusion model training. This document explains the framework's principles to help developers write new training code. If you want to use our provided default training functions, please refer to [Model Training](../Pipeline_Usage/Model_Training.md).\n\nRecalling the model training pseudocode from earlier, when we actually write code, the situation becomes extremely complex. Some models require additional guidance conditions and preprocessing, such as ControlNet; some models require cross-computation with the denoising model, such as VACE; some models require Gradient Checkpointing due to excessive VRAM demands, such as Qwen-Image's DiT.\n\nTo achieve strict consistency between inference and training, we abstractly encapsulate components like `Pipeline`, reusing inference code extensively during training. Please refer to [Integrating Pipeline](../Developer_Guide/Building_a_Pipeline.md) to understand the design of `Pipeline` components. Next, we'll introduce how the training framework utilizes `Pipeline` components to build training algorithms.\n\n## Framework Design Concept\n\nThe training module is encapsulated on top of the `Pipeline`, inheriting `DiffusionTrainingModule` from `diffsynth.diffusion.training_module`. We need to provide the necessary `__init__` and `forward` methods for the training module. 
Taking Qwen-Image's LoRA training as an example, we provide a simple script containing only basic training functions in `examples/qwen_image/model_training/special/simple/train.py` to help developers understand the design concept of the training module.\n\n```python\nclass QwenImageTrainingModule(DiffusionTrainingModule):\n    def __init__(self, device):\n        # Initialize models here.\n        pass\n\n    def forward(self, data):\n        # Compute loss here.\n        return loss\n```\n\n### `__init__`\n\nIn `__init__`, model initialization is required. First load the model, then switch it to training mode.\n\n```python\n    def __init__(self, device):\n        super().__init__()\n        # Load the pipeline\n        self.pipe = QwenImagePipeline.from_pretrained(\n            torch_dtype=torch.bfloat16,\n            device=device,\n            model_configs=[\n                ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n                ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n                ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n            ],\n            tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n        )\n        # Switch to training mode\n        self.switch_pipe_to_training_mode(\n            self.pipe,\n            lora_base_model=\"dit\",\n            lora_target_modules=\"to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj\",\n            lora_rank=32,\n        )\n```\n\nThe logic for loading models is basically consistent with inference, supporting loading models from remote and local paths. 
See [Model Inference](../Pipeline_Usage/Model_Inference.md) for details, but do not enable [VRAM Management](../Pipeline_Usage/VRAM_management.md) here.\n\n`switch_pipe_to_training_mode`, inherited from `DiffusionTrainingModule`, switches the pipeline to training mode; in this example it is configured through the `lora_base_model`, `lora_target_modules`, and `lora_rank` arguments. See its implementation in `diffsynth.diffusion.training_module` for details.\n\n### `forward`\n\nIn `forward`, the loss function value needs to be calculated. First perform preprocessing, then compute the loss function through the `Pipeline`'s [`model_fn`](../Developer_Guide/Building_a_Pipeline.md#model_fn).\n\n```python\n    def forward(self, data):\n        # Preprocess\n        inputs_posi = {\"prompt\": data[\"prompt\"]}\n        inputs_nega = {\"negative_prompt\": \"\"}\n        inputs_shared = {\n            # Assume you are using this pipeline for inference,\n            # please fill in the input parameters.\n            \"input_image\": data[\"image\"],\n            \"height\": data[\"image\"].size[1],\n            \"width\": data[\"image\"].size[0],\n            # Please do not modify the following parameters\n            # unless you clearly know what this will cause.\n            \"cfg_scale\": 1,\n            \"rand_device\": self.pipe.device,\n            \"use_gradient_checkpointing\": True,\n            \"use_gradient_checkpointing_offload\": False,\n        }\n        for unit in self.pipe.units:\n            inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega)\n        # Loss\n        loss = FlowMatchSFTLoss(self.pipe, **inputs_shared, **inputs_posi)\n        return loss\n```\n\nThe preprocessing process is consistent with the inference phase. 
Developers only need to assume they are using the `Pipeline` for inference and fill in the input parameters.\n\nThe loss function calculation reuses `FlowMatchSFTLoss` from `diffsynth.diffusion.loss`.\n\n### Starting Training\n\nThe training framework requires other modules, including:\n\n* accelerator: Training launcher provided by `accelerate`, see [`accelerate`](https://huggingface.co/docs/accelerate/index) for details\n* dataset: Generic dataset, see [`diffsynth.core.data`](../API_Reference/core/data.md) for details\n* model_logger: Model logger, see `diffsynth.diffusion.logger` for details\n\n```python\nif __name__ == \"__main__\":\n    accelerator = accelerate.Accelerator(\n        kwargs_handlers=[accelerate.DistributedDataParallelKwargs(find_unused_parameters=True)],\n    )\n    dataset = UnifiedDataset(\n        base_path=\"data/example_image_dataset\",\n        metadata_path=\"data/example_image_dataset/metadata.csv\",\n        repeat=50,\n        data_file_keys=\"image\",\n        main_data_operator=UnifiedDataset.default_image_operator(\n            base_path=\"data/example_image_dataset\",\n            height=512,\n            width=512,\n            height_division_factor=16,\n            width_division_factor=16,\n        )\n    )\n    model = QwenImageTrainingModule(accelerator.device)\n    model_logger = ModelLogger(\n        output_path=\"models/toy_model\",\n        remove_prefix_in_ckpt=\"pipe.dit.\",\n    )\n    launch_training_task(\n        accelerator, dataset, model, model_logger,\n        learning_rate=1e-5, num_epochs=1,\n    )\n```\n\nAssembling all the above code results in `examples/qwen_image/model_training/special/simple/train.py`. Use the following command to start training:\n\n```\naccelerate launch examples/qwen_image/model_training/special/simple/train.py\n```"
  },
  {
    "path": "docs/en/Training/Understanding_Diffusion_models.md",
    "content": "# Basic Principles of Diffusion Models\n\nThis document introduces the basic principles of Diffusion models to help you understand how the training framework is constructed. To make these complex mathematical theories easier for readers to understand, we have reconstructed the theoretical framework of Diffusion models, abandoning complex stochastic differential equations and presenting them in a more concise and understandable form.\n\n## Introduction\n\nDiffusion models generate clear images or video content through iterative denoising. We start by explaining the generation process of a data sample $x_0$. Intuitively, in a complete round of denoising, we start from random Gaussian noise $x_T$ and iteratively obtain $x_{T-1}$, $x_{T-2}$, $x_{T-3}$, $\\cdots$, gradually reducing the noise content at each step until we finally obtain the noise-free data sample $x_0$.\n\n![Image](https://github.com/user-attachments/assets/6471ae4c-a635-4924-8b36-b0bd4d42043d)\n\nThis process is intuitive, but to understand the details, we need to answer several questions:\n\n* How is the noise content at each step defined?\n* How is the iterative denoising computation performed?\n* How to train such Diffusion models?\n* What is the architecture of modern Diffusion models?\n* How does this project encapsulate and implement model training?\n\n## How is the noise content at each step defined?\n\nIn the theoretical system of Diffusion models, the noise content is determined by a series of parameters $\\sigma_T$, $\\sigma_{T-1}$, $\\sigma_{T-2}$, $\\cdots$, $\\sigma_0$. 
Where:\n\n* $\\sigma_T=1$, corresponding to $x_T$ as pure Gaussian noise\n* $\\sigma_T>\\sigma_{T-1}>\\sigma_{T-2}>\\cdots>\\sigma_0$, the noise content gradually decreases during iteration\n* $\\sigma_0=0$, corresponding to $x_0$ as a data sample without any noise\n\nAs for the intermediate values $\\sigma_{T-1}$, $\\sigma_{T-2}$, $\\cdots$, $\\sigma_1$, they are not fixed and only need to satisfy the decreasing condition.\n\nAt any intermediate step $t$, we can directly synthesize the noisy data sample $x_t=(1-\\sigma_t)x_0+\\sigma_t x_T$.\n\n![Image](https://github.com/user-attachments/assets/e25a2f71-123c-4e18-8b34-3a066af15667)\n\n## How is the iterative denoising computation performed?\n\nBefore understanding the iterative denoising computation, we need to clarify what the input and output of the denoising model are. We abstract the model as a symbol $\\hat \\epsilon$, whose input typically consists of three parts:\n\n* Time step $t$: the model needs to know which stage of the denoising process it is currently in\n* Noisy data sample $x_t$: the model needs to know what data to denoise\n* Guidance condition $c$: the model needs to know what kind of data sample to generate through denoising\n\nAmong these, the guidance condition $c$ is a newly introduced parameter that is input by the user. It can be text describing the image content or a sketch outlining the image structure.\n\nThe model's output $\\hat \\epsilon(x_t,c,t)$ approximately equals $x_T-x_0$, which is the direction of the entire diffusion process (the reverse process of denoising).\n\nNext, we analyze the computation occurring in one iteration. 
At time step $t$, after the model computes an approximation of $x_T-x_0$, we calculate the next $x_{t-1}$:\n\n$$\n\\begin{aligned}\nx_{t-1}&=x_t + (\\sigma_{t-1} - \\sigma_t) \\cdot \\hat \\epsilon(x_t,c,t)\\\\\n&\\approx x_t + (\\sigma_{t-1} - \\sigma_t) \\cdot (x_T-x_0)\\\\\n&=(1-\\sigma_t)x_0+\\sigma_t x_T + (\\sigma_{t-1} - \\sigma_t) \\cdot (x_T-x_0)\\\\\n&=(1-\\sigma_{t-1})x_0+\\sigma_{t-1}x_T\n\\end{aligned}\n$$\n\nPerfect! It perfectly matches the noise content definition at time step $t-1$.\n\n> (This part might be a bit difficult to understand. Don't worry; it's recommended to skip this part on first reading without affecting the rest of the document.)\n>\n> After completing this somewhat complex formula derivation, let's consider a question: why should the model's output approximately equal $x_T-x_0$? Can it be set to other values?\n>\n> Actually, Diffusion models rely on two definitions to form a complete theory. From the above formulas, we can extract these two definitions and derive the iterative formula:\n>\n> * Data definition: $x_t=(1-\\sigma_t)x_0+\\sigma_t x_T$\n> * Model definition: $\\hat \\epsilon(x_t,c,t)=x_T-x_0$\n> * Derived iterative formula: $x_{t-1}=x_t + (\\sigma_{t-1} - \\sigma_t) \\cdot \\hat \\epsilon(x_t,c,t)$\n>\n> These three mathematical formulas are complete. For example, in the previous derivation, substituting the data definition and model definition into the iterative formula yields $x_{t-1}$ that matches the data definition.\n>\n> These are two definitions built on Flow Matching theory, but Diffusion models can also be implemented with other definitions. 
For example, early models based on DDPM (Denoising Diffusion Probabilistic Models) use the following two definitions and derived iterative formula:\n>\n> * Data definition: $x_t=\\sqrt{\\alpha_t}x_0+\\sqrt{1-\\alpha_t}x_T$\n> * Model definition: $\\hat \\epsilon(x_t,c,t)=x_T$\n> * Derived iterative formula: $x_{t-1}=\\sqrt{\\alpha_{t-1}}\\left(\\frac{x_t-\\sqrt{1-\\alpha_t}\\hat \\epsilon(x_t,c,t)}{\\sqrt{\\alpha_t}}\\right)+\\sqrt{1-\\alpha_{t-1}}\\hat \\epsilon(x_t,c,t)$\n>\n> More generally, we describe the derivation process of the iterative formula using matrices. For any data definition and model definition:\n>\n> * Data definition: $x_t=C_t(x_0,x_T)^T$\n> * Model definition: $\\hat \\epsilon(x_t,c,t)=C_t^{[\\epsilon]}(x_0,x_T)^T$\n> * Derived iterative formula: $x_{t-1}=C_{t-1}(C_t,C_t^{[\\epsilon]})^{-T}(x_t,\\hat \\epsilon(x_t,c,t))^T$\n>\n> Where $C_t$ and $C_t^{[\\epsilon]}$ are $1\\times 2$ coefficient matrices. It's not difficult to see that when constructing the two definitions, the matrix $(C_t,C_t^{[\\epsilon]})^T$ must be invertible.\n>\n> Although Flow Matching and DDPM have been widely verified by numerous pre-trained models, this doesn't mean they are optimal solutions. We encourage developers to design new Diffusion model theories for better training results.\n\n## How to train such Diffusion models?\n\nAfter understanding the iterative denoising process, we next consider how to train such Diffusion models.\n\nThe training process differs from the generation process. If we retain multi-step iterations during training, the gradient would need to backpropagate through multiple steps, bringing catastrophic time and space complexity. 
To improve computational efficiency, we instead randomly sample a single time step $t$ for each training iteration.\n\nThe following is pseudocode for the training process:\n\n> Obtain data sample $x_0$ and guidance condition $c$ from the dataset\n>\n> Randomly sample time step $t\\in(0,T]$\n>\n> Randomly sample Gaussian noise $x_T\\sim \\mathcal N(0,I)$\n>\n> Construct the noisy sample $x_t=(1-\\sigma_t)x_0+\\sigma_t x_T$\n>\n> Compute the model prediction $\\hat \\epsilon(x_t,c,t)$\n>\n> Loss function $\\mathcal L=||\\hat \\epsilon(x_t,c,t)-(x_T-x_0)||_2^2$\n>\n> Backpropagate gradients and update model parameters\n\n## What is the architecture of modern Diffusion models?\n\nFrom theory to practice, more details need to be filled in. Modern Diffusion model architectures have matured, with mainstream architectures following the \"three-stage\" architecture proposed by Latent Diffusion, including data encoder-decoder, guidance condition encoder, and denoising model.\n\n![Image](https://github.com/user-attachments/assets/43855430-6427-4aca-83a0-f684e01438b1)\n\n### Data Encoder-Decoder\n\nIn the previous text, we consistently referred to $x_0$ as a \"data sample\" rather than an image or video because modern Diffusion models typically don't process images or videos directly. Instead, they use an Encoder-Decoder architecture model, usually a VAE (Variational Auto-Encoder), to encode images or videos into Embedding tensors, obtaining $x_0$.\n\nAfter data is encoded by the encoder and then decoded by the decoder, the reconstructed content is approximately consistent with the original, with minor errors. So why operate on the encoded Embedding tensors instead of directly on images or videos? The main reasons are twofold:\n\n* Encoding also compresses the data, reducing the computational load of subsequent processing.\n* The distribution of encoded data is closer to a Gaussian distribution, making it easier for denoising models to model the data.\n\nDuring generation, the encoder part doesn't participate in computation. 
After iteration completes, the decoder part decodes $x_0$ to obtain clear images or videos. During training, the decoder part doesn't participate in computation; only the encoder is used to compute $x_0$.\n\n### Guidance Condition Encoder\n\nUser-input guidance conditions $c$ can be complex and diverse, requiring specialized encoder models to process them into Embedding tensors. According to the type of guidance condition, we classify guidance condition encoders into the following categories:\n\n* Text type, such as CLIP, Qwen-VL\n* Image type, such as ControlNet, IP-Adapter\n* Video type, such as VAE\n\n> The model $\\hat \\epsilon$ mentioned in the previous text refers to the entirety of all guidance condition encoders and the denoising model. We list guidance condition encoders separately because these models are typically frozen during Diffusion training, and their output values are independent of time step $t$, allowing guidance condition encoder computations to be performed offline.\n\n### Denoising Model\n\nThe denoising model is the true essence of Diffusion models, with diverse model structures such as UNet and DiT. Model developers can freely innovate on these structures.\n\n## How does this project encapsulate and implement model training?\n\nPlease read the next document: [Standard Supervised Training](../Training/Supervised_Fine_Tuning.md)"
  },
  {
    "path": "docs/en/conf.py",
    "content": "# Configuration file for the Sphinx documentation builder.\n#\n# This file only contains a selection of the most common options. For a full\n# list see the documentation:\n# https://www.sphinx-doc.org/en/master/usage/configuration.html\n\n# -- Path setup --------------------------------------------------------------\n\n# If extensions (or modules to document with autodoc) are in another directory,\n# add these directories to sys.path here. If the directory is relative to the\n# documentation root, use os.path.abspath to make it absolute, like shown here.\n#\nimport os\nimport sys\n\n# import sphinx_book_theme\n\nsys.path.insert(0, os.path.abspath('../../'))\n# -- Project information -----------------------------------------------------\n\nproject = 'diffsynth'\ncopyright = '2022-2025, Alibaba ModelScope'\nauthor = 'ModelScope Authors'\nversion_file = '../../diffsynth/version.py'\nhtml_theme = 'sphinx_rtd_theme'\nlanguage = 'en'\n\n\ndef get_version():\n    with open(version_file, 'r', encoding='utf-8') as f:\n        exec(compile(f.read(), version_file, 'exec'))\n    return locals()['__version__']\n\n\n# The full version, including alpha/beta/rc tags\nversion = get_version()\nrelease = version\n\n# -- General configuration ---------------------------------------------------\n\n# Add any Sphinx extension module names here, as strings. 
They can be\n# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom\n# ones.\nextensions = [\n    'sphinx.ext.napoleon',\n    'sphinx.ext.autosummary',\n    'sphinx.ext.autodoc',\n    'sphinx.ext.viewcode',\n    'sphinx_markdown_tables',\n    'sphinx_copybutton',\n    \"sphinx_rtd_theme\",\n    'sphinx.ext.mathjax',\n    'myst_parser',\n]\n# build the templated autosummary files\nautosummary_generate = True\nnumpydoc_show_class_members = False\n\n# Enable overriding of function signatures in the first line of the docstring.\nautodoc_docstring_signature = True\n\n# Disable docstring inheritance\nautodoc_inherit_docstrings = False\n\n# Show type hints in the description\nautodoc_typehints = 'description'\n\n# Add parameter types if the parameter is documented in the docstring\nautodoc_typehints_description_target = 'documented_params'\n\nautodoc_default_options = {\n    'member-order': 'bysource',\n}\n\n# Add any paths that contain templates here, relative to this directory.\ntemplates_path = ['_templates']\n\n# The suffix(es) of source filenames.\n# You can specify multiple suffix as a list of string:\n#\nsource_suffix = ['.rst', '.md']\n\n# The master toctree document.\nroot_doc = 'index'\n\n# List of patterns, relative to source directory, that match files and\n# directories to ignore when looking for source files.\n# This pattern also affects html_static_path and html_extra_path.\nexclude_patterns = ['build']\n# A list of glob-style patterns [1] that are used to find source files.\n# They are matched against the source file names relative to the source directory,\n# using slashes as directory separators on all platforms.\n# The default is **, meaning that all files are recursively included from the source directory.\n# -- Options for HTML output -------------------------------------------------\n\n# The theme to use for HTML and HTML Help pages.  
See the documentation for\n# a list of builtin themes.\n#\n# html_theme = 'sphinx_book_theme'\n# html_theme_path = [sphinx_book_theme.get_html_theme_path()]\n# html_theme_options = {}\n\n# Add any paths that contain custom static files (such as style sheets) here,\n# relative to this directory. They are copied after the builtin static files,\n# so a file named \"default.css\" will overwrite the builtin \"default.css\".\nhtml_static_path = ['_static']\n# html_css_files = ['css/readthedocs.css']\n\n# -- Options for HTMLHelp output ---------------------------------------------\n# Output file base name for HTML help builder.\n\n# -- Extension configuration -------------------------------------------------\n# Ignore >>> when copying code\ncopybutton_prompt_text = r'>>> |\\.\\.\\. '\ncopybutton_prompt_is_regexp = True\n\n# Example configuration for intersphinx: refer to the Python standard library.\nintersphinx_mapping = {'https://docs.python.org/': None}\n\nmyst_enable_extensions = [\n    'amsmath',\n    'dollarmath',\n    'colon_fence',\n]\n"
  },
  {
    "path": "docs/en/index.rst",
    "content": "Welcome to DiffSynth-Studio's Documentation\n==========================================\n\n.. toctree::\n   :maxdepth: 2\n   :caption: Documentation Introduction\n\n   README\n\n.. toctree::\n   :maxdepth: 2\n   :caption: Getting Started\n\n   Pipeline_Usage/Setup\n   Pipeline_Usage/Model_Inference\n   Pipeline_Usage/VRAM_management\n   Pipeline_Usage/Model_Training\n   Pipeline_Usage/Environment_Variables\n   Pipeline_Usage/GPU_support\n\n.. toctree::\n   :maxdepth: 2\n   :caption: Model Details\n\n   Model_Details/FLUX\n   Model_Details/Wan\n   Model_Details/Qwen-Image\n   Model_Details/FLUX2\n   Model_Details/Z-Image\n   Model_Details/Anima\n   Model_Details/LTX-2\n\n.. toctree::\n   :maxdepth: 2\n   :caption: Training Framework\n\n   Training/Understanding_Diffusion_models\n   Training/Supervised_Fine_Tuning\n   Training/FP8_Precision\n   Training/Direct_Distill\n   Training/Split_Training\n   Training/Differential_LoRA\n\n.. toctree::\n   :maxdepth: 2\n   :caption: Model Integration\n\n   Developer_Guide/Integrating_Your_Model\n   Developer_Guide/Building_a_Pipeline\n   Developer_Guide/Enabling_VRAM_management\n   Developer_Guide/Training_Diffusion_Models\n\n.. toctree::\n   :maxdepth: 2\n   :caption: API Reference\n\n   API_Reference/core/attention\n   API_Reference/core/data\n   API_Reference/core/gradient\n   API_Reference/core/loader\n   API_Reference/core/vram\n\n.. toctree::\n   :maxdepth: 2\n   :caption: Research Guide\n\n   Research_Tutorial/train_from_scratch\n   Research_Tutorial/inference_time_scaling\n\n.. toctree::\n   :maxdepth: 2\n   :caption: FAQ\n\n   QA\n\nIndices and tables\n==================\n* :ref:`genindex`\n* :ref:`modindex`\n* :ref:`search`\n"
  },
  {
    "path": "docs/requirements.txt",
    "content": "docutils>=0.16.0\nmyst_parser\nrecommonmark\nsphinx>=5.3.0\nsphinx-book-theme\nsphinx-copybutton\nsphinx-autobuild\nsphinx-rtd-theme\nsphinx_markdown_tables\nsphinxcontrib-mermaid\npymdown-extensions"
  },
  {
    "path": "docs/zh/.readthedocs.yaml",
    "content": "# .readthedocs.yaml\n# Read the Docs configuration file\n# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details\n\n# Required\nversion: 2\n\n# Set the OS, Python version and other tools you might need\nbuild:\n  os: ubuntu-22.04\n  tools:\n    python: \"3.10\"\n\n# Build documentation in the \"docs/\" directory with Sphinx\nsphinx:\n  configuration: docs/zh/conf.py\n\n# Optionally build your docs in additional formats such as PDF and ePub\n# formats:\n#    - pdf\n#    - epub\n\n# Optional but recommended, declare the Python requirements required\n# to build your documentation\n# See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html\npython:\n   install:\n      - requirements: docs/requirements.txt\n"
  },
  {
    "path": "docs/zh/API_Reference/core/attention.md",
    "content": "# `diffsynth.core.attention`: 注意力机制实现\n\n`diffsynth.core.attention` 提供了注意力机制实现的路由机制，根据 `Python` 环境中的可用包和[环境变量](../../Pipeline_Usage/Environment_Variables.md#diffsynth_attention_implementation)自动选择高效的注意力机制实现。\n\n## 注意力机制\n\n注意力机制是在论文[《Attention Is All You Need》](https://arxiv.org/abs/1706.03762)中提出的模型结构，在原论文中，注意力机制按照如下公式实现：\n\n$$\n\\text{Attention}(Q, K, V) = \\text{Softmax}\\left(\n    \\frac{QK^T}{\\sqrt{d_k}}\n\\right)\nV.\n$$\n\n在 `PyTorch` 中，可以用如下代码实现：\n```python\nimport torch\n\ndef attention(query, key, value):\n    scale_factor = 1 / query.size(-1)**0.5\n    attn_weight = query @ key.transpose(-2, -1) * scale_factor\n    attn_weight = torch.softmax(attn_weight, dim=-1)\n    return attn_weight @ value\n\nquery = torch.rand(32, 8, 128, 64, dtype=torch.bfloat16, device=\"cuda\")\nkey = torch.rand(32, 8, 128, 64, dtype=torch.bfloat16, device=\"cuda\")\nvalue = torch.rand(32, 8, 128, 64, dtype=torch.bfloat16, device=\"cuda\")\noutput_1 = attention(query, key, value)\n```\n\n其中 `query`、`key`、`value` 的维度是 $(b, n, s, d)$：\n* $b$：Batch size\n* $n$: Attention head 的数量\n* $s$: 序列长度\n* $d$: 每个 Attention head 的维数\n\n这部分计算是不包含任何可训练参数的，现代 transformer 架构的模型会在进行这一计算前后经过 Linear 层，本文讨论的“注意力机制”不包含这些计算，仅包含以上代码的计算。\n\n## 更高效的实现\n\n注意到，注意力机制中 Attention Score（公式中的 $\\text{Softmax}\\left(\\frac{QK^T}{\\sqrt{d_k}}\\right)$，代码中的 `attn_weight`）的维度为 $(b, n, s, s)$，其中序列长度 $s$ 通常非常大，这导致计算的时间和空间复杂度达到平方级。以图像生成模型为例，图像的宽度和高度每增加到 2 倍，序列长度增加到 4 倍，计算量和显存需求增加到 16 倍。为了避免高昂的计算成本，需采用更高效的注意力机制实现，包括\n* Flash Attention 3：[GitHub](https://github.com/Dao-AILab/flash-attention)、[论文](https://arxiv.org/abs/2407.08608)\n* Flash Attention 2：[GitHub](https://github.com/Dao-AILab/flash-attention)、[论文](https://arxiv.org/abs/2307.08691)\n* Sage Attention：[GitHub](https://github.com/thu-ml/SageAttention)、[论文](https://arxiv.org/abs/2505.11594)\n* 
xFormers：[GitHub](https://github.com/facebookresearch/xformers)、[文档](https://facebookresearch.github.io/xformers/components/ops.html#module-xformers.ops)\n* PyTorch：[GitHub](https://github.com/pytorch/pytorch)、[文档](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)\n\n如需调用除 `PyTorch` 外的其他注意力实现，请按照其 GitHub 页面的指引安装对应的包。`DiffSynth-Studio` 会自动根据 Python 环境中的可用包路由到对应的实现上，也可通过[环境变量](../../Pipeline_Usage/Environment_Variables.md#diffsynth_attention_implementation)控制。\n\n```python\nfrom diffsynth.core.attention import attention_forward\nimport torch\n\ndef attention(query, key, value):\n    scale_factor = 1 / query.size(-1)**0.5\n    attn_weight = query @ key.transpose(-2, -1) * scale_factor\n    attn_weight = torch.softmax(attn_weight, dim=-1)\n    return attn_weight @ value\n\nquery = torch.rand(32, 8, 128, 64, dtype=torch.bfloat16, device=\"cuda\")\nkey = torch.rand(32, 8, 128, 64, dtype=torch.bfloat16, device=\"cuda\")\nvalue = torch.rand(32, 8, 128, 64, dtype=torch.bfloat16, device=\"cuda\")\noutput_1 = attention(query, key, value)\noutput_2 = attention_forward(query, key, value)\nprint((output_1 - output_2).abs().mean())\n```\n\n请注意，加速的同时会引入误差，但在大多数情况下误差是可以忽略不计的。\n\n## 开发者导引\n\n在为 `DiffSynth-Studio` 接入新模型时，开发者可自行决定是否调用 `diffsynth.core.attention` 中的 `attention_forward`，但我们期望模型能够尽可能优先调用这一模块，以便让新的注意力机制实现能够在这些模型上直接生效。\n\n## 最佳实践\n\n**在大多数情况下，我们建议直接使用 `PyTorch` 原生的实现，无需安装任何额外的包。** 虽然其他注意力机制实现可以加速，但加速效果是较为有限的，在少数情况下会出现兼容性和精度不足的问题。\n\n此外，高效的注意力机制实现会逐步集成到 `PyTorch` 中，`PyTorch` 的 `2.9.0` 版本中的 `scaled_dot_product_attention` 已经集成了 Flash Attention 2。我们仍在 `DiffSynth-Studio` 提供这一接口，是为了让一些激进的加速方案能够快速走向应用，尽管它们在稳定性上还需要时间验证。\n"
  },
  {
    "path": "docs/zh/API_Reference/core/data.md",
    "content": "# `diffsynth.core.data`: 数据处理算子与通用数据集\n\n## 数据处理算子\n\n### 可用数据处理算子\n\n`diffsynth.core.data` 提供了一系列数据处理算子，用于进行数据处理，包括：\n\n* 数据格式转换算子\n    * `ToInt`: 转换为 int 格式\n    * `ToFloat`: 转换为 float 格式\n    * `ToStr`: 转换为 str 格式\n    * `ToList`: 转换为列表格式，以列表包裹此数据\n    * `ToAbsolutePath`: 将相对路径转换为绝对路径\n* 文件加载算子\n    * `LoadImage`: 读取图片文件\n    * `LoadVideo`: 读取视频文件\n    * `LoadAudio`: 读取音频文件\n    * `LoadGIF`: 读取 GIF 文件\n    * `LoadTorchPickle`: 读取由 [`torch.save`](https://docs.pytorch.org/docs/stable/generated/torch.save.html) 保存的二进制文件【该算子可能导致二进制文件中的代码注入攻击，请谨慎使用！】\n* 媒体文件处理算子\n    * `ImageCropAndResize`: 对图像进行裁剪和拉伸\n* Meta 算子\n    * `SequencialProcess`: 将序列中的每个数据路由到一个算子\n    * `RouteByExtensionName`: 按照文件扩展名路由到特定算子\n    * `RouteByType`: 按照数据类型路由到特定算子\n\n### 算子使用\n\n数据算子之间以 `>>` 符号连接形成数据处理流水线，例如：\n\n```python\nfrom diffsynth.core.data.operators import *\n\ndata = \"image.jpg\"\ndata_pipeline = ToAbsolutePath(base_path=\"/data\") >> LoadImage() >> ImageCropAndResize(max_pixels=512*512)\ndata = data_pipeline(data)\n```\n\n在经过每个算子后，数据被依次处理\n\n* `ToAbsolutePath(base_path=\"/data\")`: `\"/data/image.jpg\"`\n* `LoadImage()`: `<PIL.Image.Image image mode=RGB size=1024x1024 at 0x7F8E7AAEFC10>`\n* `ImageCropAndResize(max_pixels=512*512)`: `<PIL.Image.Image image mode=RGB size=512x512 at 0x7F8E7A936F20>`\n\n我们可以组合出功能完备的数据流水线，例如通用数据集的默认视频数据算子为\n\n```python\nRouteByType(operator_map=[\n    (str, ToAbsolutePath(base_path) >> RouteByExtensionName(operator_map=[\n        ((\"jpg\", \"jpeg\", \"png\", \"webp\"), LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor) >> ToList()),\n        ((\"gif\",), LoadGIF(\n            num_frames, time_division_factor, time_division_remainder,\n            frame_processor=ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor),\n        )),\n        ((\"mp4\", \"avi\", \"mov\", \"wmv\", \"mkv\", \"flv\", \"webm\"), LoadVideo(\n            num_frames, 
time_division_factor, time_division_remainder,\n            frame_processor=ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor),\n        )),\n    ])),\n])\n```\n\n它包含如下逻辑：\n\n* 如果是 `str` 类型的数据\n    * 如果是 `\"jpg\", \"jpeg\", \"png\", \"webp\"` 类型文件\n        * 加载这张图片\n        * 裁剪并缩放到特定分辨率\n        * 打包进列表，视为单帧视频\n    * 如果是 `\"gif\"` 类型文件\n        * 加载 gif 文件内容\n        * 将每一帧裁剪和缩放到特定分辨率\n    * 如果是 `\"mp4\", \"avi\", \"mov\", \"wmv\", \"mkv\", \"flv\", \"webm\"` 类型文件\n        * 加载视频文件内容\n        * 将每一帧裁剪和缩放到特定分辨率\n* 如果不是 `str` 类型的数据，报错\n\n## 通用数据集\n\n`diffsynth.core.data` 提供了统一的数据集实现，数据集需输入以下参数：\n\n* `base_path`: 根目录，若数据集中包含图片文件的相对路径，则需填入此字段用于加载这些路径指向的文件\n* `metadata_path`: 元数据文件路径，该文件记录了所有元数据，支持 `csv`、`json`、`jsonl` 格式\n* `repeat`: 数据重复次数，默认为 1，该参数影响一个 epoch 的训练步数\n* `data_file_keys`: 需进行加载的数据字段名，例如 `(image, edit_image)`\n* `main_data_operator`: 主加载算子，需通过数据处理算子组装好数据处理流水线\n* `special_operator_map`: 特殊算子映射，对需要特殊处理的字段构建的算子映射\n\n### 元数据\n\n数据集的 `metadata_path` 指向元数据文件，支持 `csv`、`json`、`jsonl` 格式，以下提供了样例\n\n* `csv` 格式：可读性高、不支持列表数据、内存占用小\n\n```csv\nimage,prompt\nimage_1.jpg,\"a dog\"\nimage_2.jpg,\"a cat\"\n```\n\n* `json` 格式：可读性高、支持列表数据、内存占用大\n\n```json\n[\n    {\n        \"image\": \"image_1.jpg\",\n        \"prompt\": \"a dog\"\n    },\n    {\n        \"image\": \"image_2.jpg\",\n        \"prompt\": \"a cat\"\n    }\n]\n```\n\n* `jsonl` 格式：可读性低、支持列表数据、内存占用小\n\n```json\n{\"image\": \"image_1.jpg\", \"prompt\": \"a dog\"}\n{\"image\": \"image_2.jpg\", \"prompt\": \"a cat\"}\n```\n\n如何选择最佳的元数据格式？\n\n* 如果数据量大，达到千万级的数据量，由于 `json` 文件解析时需要额外内存，此时不可用，请使用 `csv` 或 `jsonl` 格式\n* 如果数据集中包含列表数据，例如编辑模型需输入多张图，由于 `csv` 格式无法存储列表格式数据，此时不可用，请使用 `json` 或 `jsonl` 格式\n\n### 数据加载逻辑\n\n在没有进行额外设置时，数据集默认输出元数据集中的数据，图片和视频文件的路径会以字符串的格式输出，若要加载这些文件，则需要设置 `data_file_keys`、`main_data_operator`、`special_operator_map`。\n\n在数据处理流程中，按如下逻辑进行处理：\n* 如果字段位于 `special_operator_map`，则调用 `special_operator_map` 中的对应算子进行处理\n* 如果字段不位于 `special_operator_map`\n    * 如果字段位于 
`data_file_keys`，则调用 `main_data_operator` 算子进行处理\n    * 如果字段不位于 `data_file_keys`，则不进行处理\n\n`special_operator_map` 可用于实现特殊的数据处理，例如模型 [Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B) 中输入的人物面部视频 `animate_face_video` 是以固定分辨率处理的，与输出视频不一致，因此这一字段由专门的算子处理：\n\n```python\nspecial_operator_map={\n    \"animate_face_video\": ToAbsolutePath(args.dataset_base_path) >> LoadVideo(args.num_frames, 4, 1, frame_processor=ImageCropAndResize(512, 512, None, 16, 16)),\n}\n```\n\n### 其他注意事项\n\n当数据量过少时，可适当增加 `repeat`，延长单个 epoch 的训练时间，避免频繁保存模型产生较多耗时。\n\n当数据量 * `repeat` 超过 $10^9$ 时，我们观测到数据集的速度明显变慢，这似乎是 `PyTorch` 的 bug，我们尚不确定新版本的 `PyTorch` 是否已经修复了这一问题。\n"
  },
  {
    "path": "docs/zh/API_Reference/core/gradient.md",
    "content": "# `diffsynth.core.gradient`: 梯度检查点及其 Offload\n\n`diffsynth.core.gradient` 中提供了封装好的梯度检查点及其 Offload 版本，用于模型训练。\n\n## 梯度检查点\n\n梯度检查点是用于减少训练时显存占用的技术。我们提供一个例子来帮助你理解这一技术，以下是一个简单的模型结构\n\n```python\nimport torch\n\nclass ToyModel(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.activation = torch.nn.Sigmoid()\n    \n    def forward(self, x):\n        return self.activation(x)\n\nmodel = ToyModel()\nx = torch.randn((2, 3))\ny = model(x)\n```\n\n在这个模型结构中，输入的参数 $x$ 经过 Sigmoid 激活函数得到输出值 $y=\\frac{1}{1+e^{-x}}$。\n\n在训练过程中，假定我们的损失函数值为 $\\mathcal L$，在梯度反向传播时，我们得到 $\\frac{\\partial \\mathcal L}{\\partial y}$，此时我们需计算 $\\frac{\\partial \\mathcal L}{\\partial x}$，不难发现 $\\frac{\\partial y}{\\partial x}=y(1-y)$，进而有 $\\frac{\\partial \\mathcal L}{\\partial x}=\\frac{\\partial \\mathcal L}{\\partial y}\\frac{\\partial y}{\\partial x}=\\frac{\\partial \\mathcal L}{\\partial y}y(1-y)$。如果在模型前向传播时保存 $y$ 的数值，并在梯度反向传播时直接计算 $y(1-y)$，这将避免复杂的 exp 计算，加快计算速度，但这会导致我们需要额外的显存来存储中间变量 $y$。\n\n不启用梯度检查点时，训练框架会默认存储所有辅助梯度计算的中间变量，从而达到最佳的计算速度。开启梯度检查点时，中间变量则不会存储，但输入参数 $x$ 仍会存储，减少显存占用，在梯度反向传播时需重新计算这些变量，减慢计算速度。\n\n## 启用梯度检查点及其 Offload\n\n`diffsynth.core.gradient` 中的 `gradient_checkpoint_forward` 实现了梯度检查点及其 Offload，可参考以下代码调用：\n\n```python\nimport torch\nfrom diffsynth.core.gradient import gradient_checkpoint_forward\n\nclass ToyModel(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.activation = torch.nn.Sigmoid()\n    \n    def forward(self, x):\n        return self.activation(x)\n\nmodel = ToyModel()\nx = torch.randn((2, 3))\ny = gradient_checkpoint_forward(\n    model,\n    use_gradient_checkpointing=True,\n    use_gradient_checkpointing_offload=False,\n    x=x,\n)\n```\n\n* 当 `use_gradient_checkpointing=False` 且 `use_gradient_checkpointing_offload=False` 时，计算过程与原始计算完全相同，不影响模型的推理和训练，你可以直接将其集成到代码中。\n* 当 `use_gradient_checkpointing=True` 且 `use_gradient_checkpointing_offload=False` 时，启用梯度检查点。\n* 当 
`use_gradient_checkpointing_offload=True` 时，启用梯度检查点，所有梯度检查点的输入参数存储在内存中，进一步降低显存占用和减慢计算速度。\n\n## 最佳实践\n\n> Q: 应当在何处启用梯度检查点？\n> \n> A: 对整个模型启用梯度检查点时，计算效率和显存占用并不是最优的，我们需要设置细粒度的梯度检查点，但同时不希望为框架增加过多繁杂的代码。因此我们建议在 `Pipeline` 的 `model_fn` 中实现，例如 `diffsynth/pipelines/qwen_image.py` 中的 `model_fn_qwen_image`，在 Block 层级启用梯度检查点，不需要修改模型结构的任何代码。\n\n> Q: 什么情况下需要启用梯度检查点？\n> \n> A: 随着模型参数量越来越大，梯度检查点已成为必要的训练技术，梯度检查点通常是需要启用的。梯度检查点的 Offload 则仅需在激活值占用显存过大的模型（例如视频生成模型）中启用。\n"
  },
  {
    "path": "docs/zh/API_Reference/core/loader.md",
    "content": "# `diffsynth.core.loader`: 模型下载与加载\n\n本文档介绍 `diffsynth.core.loader` 中模型下载与加载相关的功能。\n\n## ModelConfig\n\n`diffsynth.core.loader` 中的 `ModelConfig` 用于标注模型下载来源、本地路径、显存管理配置等信息。\n\n### 从远程下载并加载模型\n\n以模型[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny) 为例，在 `ModelConfig` 中填写 `model_id` 和 `origin_file_pattern` 后即可自动下载模型。默认下载到 `./models` 路径，该路径可通过[环境变量 DIFFSYNTH_MODEL_BASE_PATH](../../Pipeline_Usage/Environment_Variables.md#diffsynth_model_base_path) 修改。\n\n默认情况下，即使模型已经下载完毕，程序仍会向远程查询是否有遗漏文件，如果要完全关闭远程请求，请将[环境变量 DIFFSYNTH_SKIP_DOWNLOAD](../../Pipeline_Usage/Environment_Variables.md#diffsynth_skip_download) 设置为 `True`。\n\n```python\nfrom diffsynth.core import ModelConfig\n\nconfig = ModelConfig(\n    model_id=\"DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny\",\n    origin_file_pattern=\"model.safetensors\",\n)\n# Download models\nconfig.download_if_necessary()\nprint(config.path)\n```\n\n调用 `download_if_necessary` 后，模型会自动下载，并将路径返回到 `config.path` 中。\n\n### 从本地路径加载模型\n\n如果从本地路径加载模型，则需要填入 `path`：\n\n```python\nfrom diffsynth.core import ModelConfig\n\nconfig = ModelConfig(path=\"models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny/model.safetensors\")\n```\n\n如果模型包含多个分片文件，以列表的形式输入即可：\n\n```python\nfrom diffsynth.core import ModelConfig\n\nconfig = ModelConfig(path=[\n    \"models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors\",\n    \"models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors\",\n    \"models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors\",\n    \"models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors\"\n])\n```\n\n### 显存管理配置\n\n`ModelConfig` 也包含了显存管理配置信息，详见[显存管理](../../Pipeline_Usage/VRAM_management.md#更多使用方式)。\n\n## 模型文件加载\n\n`diffsynth.core.loader` 提供了统一的 `load_state_dict`，用于加载模型文件中的 state dict。\n\n加载单个模型文件：\n\n```python\nfrom diffsynth.core import load_state_dict\n\nstate_dict 
= load_state_dict(\"models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny/model.safetensors\")\n```\n\n加载多个模型文件（合并为一个 state dict）：\n\n```python\nfrom diffsynth.core import load_state_dict\n\nstate_dict = load_state_dict([\n    \"models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors\",\n    \"models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors\",\n    \"models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors\",\n    \"models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors\"\n])\n```\n\n## 模型哈希\n\n模型哈希是用于判断模型类型的，哈希值可通过 `hash_model_file` 获取：\n\n```python\nfrom diffsynth.core import hash_model_file\n\nprint(hash_model_file(\"models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny/model.safetensors\"))\n```\n\n也可计算多个模型文件的哈希值，等价于合并 state dict 后计算模型哈希值：\n\n```python\nfrom diffsynth.core import hash_model_file\n\nprint(hash_model_file([\n    \"models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors\",\n    \"models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors\",\n    \"models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors\",\n    \"models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors\"\n]))\n```\n\n模型哈希值只与模型文件中 state dict 的 keys 和 tensor shape 有关，与模型参数的数值、文件保存时间等信息无关。在计算 `.safetensors` 格式文件的模型哈希值时，`hash_model_file` 是几乎瞬间完成的，无需读取模型的参数；但在计算 `.bin`、`.pth`、`.ckpt` 等二进制文件的模型哈希值时，则需要读取全部模型参数，因此**我们不建议开发者继续使用这些格式的文件。**\n\n通过[编写模型 Config](../../Developer_Guide/Integrating_Your_Model.md#step-3-编写模型-config)并将模型哈希值等信息填入 `diffsynth/configs/model_configs.py`，开发者可以让 `DiffSynth-Studio` 自动识别模型类型并加载。\n\n## 模型加载\n\n`load_model` 是 `diffsynth.core.loader` 中加载模型的外部入口，它会调用 [skip_model_initialization](../../API_Reference/core/vram.md#跳过模型参数初始化) 跳过模型参数初始化。如果启用了 [Disk Offload](../../Pipeline_Usage/VRAM_management.md#disk-offload)，则调用 [DiskMap](../../API_Reference/core/vram.md#state-dict-硬盘映射) 进行惰性加载；如果没有启用 Disk Offload，则调用 [load_state_dict](#模型文件加载) 
加载模型参数。如果需要的话，还会调用 [state dict converter](../../Developer_Guide/Integrating_Your_Model.md#step-2-模型文件格式转换) 进行模型格式转换。最后调用 `model.eval()` 将其切换到推理模式。\n\n以下是一个启用了 Disk Offload 的使用案例：\n\n```python\nfrom diffsynth.core import load_model, enable_vram_management, AutoWrappedLinear, AutoWrappedModule\nfrom diffsynth.models.qwen_image_dit import QwenImageDiT, RMSNorm\nimport torch\n\nprefix = \"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model\"\nmodel_path = [prefix + f\"-0000{i}-of-00009.safetensors\" for i in range(1, 10)]\n\nmodel = load_model(\n    QwenImageDiT,\n    model_path,\n    module_map={\n        torch.nn.Linear: AutoWrappedLinear,\n        RMSNorm: AutoWrappedModule,\n    },\n    vram_config={\n        \"offload_dtype\": \"disk\",\n        \"offload_device\": \"disk\",\n        \"onload_dtype\": \"disk\",\n        \"onload_device\": \"disk\",\n        \"preparing_dtype\": torch.bfloat16,\n        \"preparing_device\": \"cuda\",\n        \"computation_dtype\": torch.bfloat16,\n        \"computation_device\": \"cuda\",\n    },\n    vram_limit=0,\n)\n```\n"
  },
  {
    "path": "docs/zh/API_Reference/core/vram.md",
    "content": "# `diffsynth.core.vram`: 显存管理\n\n本文档介绍 `diffsynth.core.vram` 中的显存管理底层功能，如果你希望将这些功能用于其他的代码库中，可参考本文档。\n\n## 跳过模型参数初始化\n\n在 `PyTorch` 中加载模型时，模型的参数默认会占用显存或内存并进行参数初始化，而这些参数会在加载预训练权重后被覆盖掉，这导致了冗余的计算。`PyTorch` 中没有提供接口来跳过这些冗余的计算，我们在 `diffsynth.core.vram` 中提供了 `skip_model_initialization` 用于跳过模型参数初始化。\n\n默认的模型加载方式：\n\n```python\nfrom diffsynth.core import load_state_dict\nfrom diffsynth.models.qwen_image_controlnet import QwenImageBlockWiseControlNet\n\nmodel = QwenImageBlockWiseControlNet() # Slow\npath = \"models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny/model.safetensors\"\nstate_dict = load_state_dict(path, device=\"cpu\")\nmodel.load_state_dict(state_dict, assign=True)\n```\n\n跳过参数初始化的模型加载方式：\n\n```python\nfrom diffsynth.core import load_state_dict, skip_model_initialization\nfrom diffsynth.models.qwen_image_controlnet import QwenImageBlockWiseControlNet\n\nwith skip_model_initialization():\n    model = QwenImageBlockWiseControlNet() # Fast\npath = \"models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny/model.safetensors\"\nstate_dict = load_state_dict(path, device=\"cpu\")\nmodel.load_state_dict(state_dict, assign=True)\n```\n\n在 `DiffSynth-Studio` 中，所有预训练模型都遵循这一加载逻辑。开发者在[接入模型](../../Developer_Guide/Integrating_Your_Model.md)完毕后即可直接以这种方式快速加载模型。\n\n## State Dict 硬盘映射\n\n对于某个模型的预训练权重文件，如果我们只需要读取其中的一组参数，而非全部参数，State Dict 硬盘映射可以加速这一过程。我们在 `diffsynth.core.vram` 中提供了 `DiskMap` 用于按需加载模型参数。\n\n默认的权重加载方式：\n\n```python\nfrom diffsynth.core import load_state_dict\n\npath = \"models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny/model.safetensors\"\nstate_dict = load_state_dict(path, device=\"cpu\") # Slow\nprint(state_dict[\"img_in.weight\"])\n```\n\n使用 `DiskMap` 只加载特定参数：\n\n```python\nfrom diffsynth.core import DiskMap\n\npath = \"models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny/model.safetensors\"\nstate_dict = DiskMap(path, device=\"cpu\") # Fast\nprint(state_dict[\"img_in.weight\"])\n```\n\n`DiskMap` 是 
`DiffSynth-Studio` 中 Disk Offload 的基本组件，开发者在[配置细粒度显存管理方案](../../Developer_Guide/Enabling_VRAM_management.md)后即可直接启用 Disk Offload。\n\n`DiskMap` 是利用 `.safetensors` 文件的特性实现的功能，因此在使用 `.bin`、`.pth`、`.ckpt` 等二进制文件时，模型的参数是全量加载的，这也导致 Disk Offload 不支持这些格式的文件。**我们不建议开发者继续使用这些格式的文件。**\n\n## 显存管理可替换模块\n\n在启用 `DiffSynth-Studio` 的显存管理后，模型内部的模块会被替换为 `diffsynth.core.vram.layers` 中的可替换模块，其使用方式详见[细粒度显存管理方案](../../Developer_Guide/Enabling_VRAM_management.md#编写细粒度显存管理方案)。\n"
  },
  {
    "path": "docs/zh/Developer_Guide/Building_a_Pipeline.md",
    "content": "# 接入 Pipeline\n\n在[将 Pipeline 所需的模型接入](../Developer_Guide/Integrating_Your_Model.md)之后，还需构建 `Pipeline` 用于模型推理，本文档提供 `Pipeline` 构建的标准化流程，开发者也可参考现有的 `Pipeline` 进行构建。\n\n`Pipeline` 的实现位于 `diffsynth/pipelines`，每个 `Pipeline` 包含以下必要的关键组件：\n\n* `__init__`\n* `from_pretrained`\n* `__call__`\n* `units`\n* `model_fn`\n\n## `__init__`\n\n在 `__init__` 中，`Pipeline` 进行初始化，以下是一个简易的实现：\n\n```python\nimport torch\nfrom PIL import Image\nfrom typing import Union\nfrom tqdm import tqdm\nfrom ..diffusion import FlowMatchScheduler\nfrom ..core import ModelConfig\nfrom ..diffusion.base_pipeline import BasePipeline, PipelineUnit\nfrom ..models.new_models import XXX_Model, YYY_Model, ZZZ_Model\n\nclass NewDiffSynthPipeline(BasePipeline):\n\n    def __init__(self, device=\"cuda\", torch_dtype=torch.bfloat16):\n        super().__init__(device=device, torch_dtype=torch_dtype)\n        self.scheduler = FlowMatchScheduler()\n        self.text_encoder: XXX_Model = None\n        self.dit: YYY_Model = None\n        self.vae: ZZZ_Model = None\n        self.in_iteration_models = (\"dit\",)\n        self.units = [\n            NewDiffSynthPipelineUnit_xxx(),\n            ...\n        ]\n        self.model_fn = model_fn_new\n```\n\n其中包括以下几部分\n\n* `scheduler`: 调度器，用于控制推理的迭代公式中的系数，控制每一步的噪声含量。\n* `text_encoder`、`dit`、`vae`: 模型，自 [Latent Diffusion](https://arxiv.org/abs/2112.10752) 被提出以来，这种三段式模型架构已成为主流的 Diffusion 模型架构，但这并不是一成不变的，`Pipeline` 中可添加任意多个模型。\n* `in_iteration_models`: 迭代中模型，这个元组标注了在迭代中会调用哪些模型。\n* `units`: 模型迭代的前处理单元，详见[`units`](#units)。\n* `model_fn`: 迭代中去噪模型的 `forward` 函数，详见[`model_fn`](#model_fn)。\n\n> Q: 模型加载并不发生在 `__init__`，为什么这里仍要将每个模型初始化为 `None`？\n> \n> A: 在这里标注每个模型的类型后，代码编辑器就可以根据每个模型提供代码补全提示，便于后续的开发。\n\n## `from_pretrained`\n\n`from_pretrained` 负责加载所需的模型，让 `Pipeline` 变成可调用的状态。以下是一个简易的实现：\n\n```python\n    @staticmethod\n    def from_pretrained(\n        torch_dtype: torch.dtype = torch.bfloat16,\n        device: Union[str, torch.device] = \"cuda\",\n        model_configs: 
list[ModelConfig] = [],\n        vram_limit: float = None,\n    ):\n        # Initialize pipeline\n        pipe = NewDiffSynthPipeline(device=device, torch_dtype=torch_dtype)\n        model_pool = pipe.download_and_load_models(model_configs, vram_limit)\n        \n        # Fetch models\n        pipe.text_encoder = model_pool.fetch_model(\"xxx_text_encoder\")\n        pipe.dit = model_pool.fetch_model(\"yyy_dit\")\n        pipe.vae = model_pool.fetch_model(\"zzz_vae\")\n        # If necessary, load tokenizers here.\n        \n        # VRAM Management\n        pipe.vram_management_enabled = pipe.check_vram_management_state()\n        return pipe\n```\n\n开发者需要实现其中获取模型的逻辑，对应的模型名称即为[模型接入时填写的模型 Config](../Developer_Guide/Integrating_Your_Model.md#step-3-编写模型-config) 中的 `\"model_name\"`。\n\n部分模型还需要加载 `tokenizer`，可根据需要在 `from_pretrained` 上添加额外的 `tokenizer_config` 参数并在获取模型后实现这部分。\n\n## `__call__`\n\n`__call__` 实现了整个 Pipeline 的生成过程，以下是常见的生成过程模板，开发者可根据需要在此基础上修改。\n\n```python\n    @torch.no_grad()\n    def __call__(\n        self,\n        prompt: str,\n        negative_prompt: str = \"\",\n        cfg_scale: float = 4.0,\n        input_image: Image.Image = None,\n        denoising_strength: float = 1.0,\n        height: int = 1328,\n        width: int = 1328,\n        seed: int = None,\n        rand_device: str = \"cpu\",\n        num_inference_steps: int = 30,\n        progress_bar_cmd = tqdm,\n    ):\n        # Scheduler\n        self.scheduler.set_timesteps(\n            num_inference_steps,\n            denoising_strength=denoising_strength\n        )\n        \n        # Parameters\n        inputs_posi = {\n            \"prompt\": prompt,\n        }\n        inputs_nega = {\n            \"negative_prompt\": negative_prompt,\n        }\n        inputs_shared = {\n            \"cfg_scale\": cfg_scale,\n            \"input_image\": input_image,\n            \"denoising_strength\": denoising_strength,\n            \"height\": height,\n            \"width\": width,\n        
    \"seed\": seed,\n            \"rand_device\": rand_device,\n            \"num_inference_steps\": num_inference_steps,\n        }\n        for unit in self.units:\n            inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)\n\n        # Denoise\n        self.load_models_to_device(self.in_iteration_models)\n        models = {name: getattr(self, name) for name in self.in_iteration_models}\n        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):\n            timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)\n\n            # Inference\n            noise_pred_posi = self.model_fn(**models, **inputs_shared, **inputs_posi, timestep=timestep, progress_id=progress_id)\n            if cfg_scale != 1.0:\n                noise_pred_nega = self.model_fn(**models, **inputs_shared, **inputs_nega, timestep=timestep, progress_id=progress_id)\n                noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)\n            else:\n                noise_pred = noise_pred_posi\n\n            # Scheduler\n            inputs_shared[\"latents\"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)\n        \n        # Decode\n        self.load_models_to_device(['vae'])\n        image = self.vae.decode(inputs_shared[\"latents\"], device=self.device)\n        image = self.vae_output_to_image(image)\n        self.load_models_to_device([])\n\n        return image\n```\n\n## `units`\n\n`units` 包含了所有的前处理过程，例如：宽高检查、提示词编码、初始噪声生成等。在整个模型前处理过程中，数据被抽象为了互斥的三部分，分别存储在对应的字典中：\n\n* `inputs_shared`: 共享输入，与 [Classifier-Free Guidance](https://arxiv.org/abs/2207.12598)（简称 CFG）无关的参数。\n* `inputs_posi`: Classifier-Free Guidance 的 Positive 侧输入，包含与正向提示词相关的内容。\n* `inputs_nega`: Classifier-Free Guidance 的 Negative 侧输入，包含与负向提示词相关的内容。\n\nPipeline Unit 的实现包括三种：直接模式、CFG 分离模式、接管模式。\n\n如果某些计算与 CFG 无关，可采用直接模式，例如 Qwen-Image 
的随机噪声初始化：\n\n```python\nclass QwenImageUnit_NoiseInitializer(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            input_params=(\"height\", \"width\", \"seed\", \"rand_device\"),\n            output_params=(\"noise\",),\n        )\n\n    def process(self, pipe: QwenImagePipeline, height, width, seed, rand_device):\n        noise = pipe.generate_noise((1, 16, height//8, width//8), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)\n        return {\"noise\": noise}\n```\n\n如果某些计算与 CFG 有关，需分别处理正向和负向提示词，但两侧的输入参数是相同的，可采用 CFG 分离模式，例如 Qwen-image 的提示词编码：\n\n```python\nclass QwenImageUnit_PromptEmbedder(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            seperate_cfg=True,\n            input_params_posi={\"prompt\": \"prompt\"},\n            input_params_nega={\"prompt\": \"negative_prompt\"},\n            input_params=(\"edit_image\",),\n            output_params=(\"prompt_emb\", \"prompt_emb_mask\"),\n            onload_model_names=(\"text_encoder\",)\n        )\n\n    def process(self, pipe: QwenImagePipeline, prompt, edit_image=None) -> dict:\n        pipe.load_models_to_device(self.onload_model_names)\n        # Do something\n        return {\"prompt_emb\": prompt_embeds, \"prompt_emb_mask\": encoder_attention_mask}\n```\n\n如果某些计算需要全局的信息，则需要接管模式，例如 Qwen-Image 的实体分区控制：\n\n```python\nclass QwenImageUnit_EntityControl(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            take_over=True,\n            input_params=(\"eligen_entity_prompts\", \"width\", \"height\", \"eligen_enable_on_negative\", \"cfg_scale\"),\n            output_params=(\"entity_prompt_emb\", \"entity_masks\", \"entity_prompt_emb_mask\"),\n            onload_model_names=(\"text_encoder\",)\n        )\n\n    def process(self, pipe: QwenImagePipeline, inputs_shared, inputs_posi, inputs_nega):\n        # Do something\n        return inputs_shared, inputs_posi, inputs_nega\n```\n\n以下是 Pipeline Unit 
所需的参数配置：\n\n* `seperate_cfg`: 是否启用 CFG 分离模式\n* `take_over`: 是否启用接管模式\n* `input_params`: 共享输入参数\n* `output_params`: 输出参数\n* `input_params_posi`: Positive 侧输入参数\n* `input_params_nega`: Negative 侧输入参数\n* `onload_model_names`: 需调用的模型组件名\n\n在设计 `unit` 时请尽量按照以下原则进行：\n\n* 缺省兜底：可选功能的 `unit` 输入参数默认为 `None`，而不是 `False` 或其他数值，请对此默认值进行兜底处理。\n* 参数触发：部分 Adapter 模型可能是未被加载的，例如 ControlNet，对应的 `unit` 应当以参数输入是否为 `None` 来控制触发，而不是以模型是否被加载来控制触发。例如当用户输入了 `controlnet_image` 但没有加载 ControlNet 模型时，代码应当给出报错，而不是忽略这些输入参数继续执行。\n* 简洁优先：尽可能使用直接模式，仅当功能无法实现时，使用接管模式。\n* 显存高效：在 `unit` 中调用模型时，请使用 `pipe.load_models_to_device(self.onload_model_names)` 激活对应的模型，请不要调用 `onload_model_names` 之外的其他模型，`unit` 计算完成后，请不要使用 `pipe.load_models_to_device([])` 手动释放显存。\n\n> Q: 部分参数并未在推理过程中调用，例如 `output_params`，是否仍有必要配置？\n> \n> A: 这些参数不会影响推理过程，但会影响一些实验性功能，因此我们建议将其配置好。例如“拆分训练”，我们可以将训练中的前处理离线完成，但部分需要梯度回传的模型计算无法拆分，这些参数用于构建计算图从而推断哪些计算是可以拆分的。\n\n## `model_fn`\n\n`model_fn` 是迭代中的统一 `forward` 接口，对于开源模型生态尚未形成的模型，直接沿用去噪模型的 `forward` 即可，例如：\n\n```python\ndef model_fn_new(dit=None, latents=None, timestep=None, prompt_emb=None, **kwargs):\n    return dit(latents, prompt_emb, timestep)\n```\n\n对于开源生态丰富的模型，`model_fn` 通常包含复杂且混乱的跨模型推理，以 `diffsynth/pipelines/qwen_image.py` 为例，这个函数中实现的额外计算包括：实体分区控制、三种 ControlNet、Gradient Checkpointing 等，开发者在实现这一部分时要格外小心，避免模块功能之间的冲突。\n"
  },
  {
    "path": "docs/zh/Developer_Guide/Enabling_VRAM_management.md",
    "content": "# 细粒度显存管理方案\n\n本文档介绍如何为模型编写合理的细粒度显存管理方案，以及如何将 `DiffSynth-Studio` 中的显存管理功能用于外部的其他代码库，在阅读本文档前，请先阅读文档[显存管理](../Pipeline_Usage/VRAM_management.md)。\n\n## 20B 模型需要多少显存？\n\n以 Qwen-Image 的 DiT 模型为例，这一模型的参数量达到了 20B，以下代码会加载这一模型并进行推理，需要约 40G 显存，这个模型在显存较小的消费级 GPU 上显然是无法运行的。\n\n```python\nfrom diffsynth.core import load_model\nfrom diffsynth.models.qwen_image_dit import QwenImageDiT\nfrom modelscope import snapshot_download\nimport torch\n\nsnapshot_download(\n    model_id=\"Qwen/Qwen-Image\",\n    local_dir=\"models/Qwen/Qwen-Image\",\n    allow_file_pattern=\"transformer/*\"\n)\nprefix = \"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model\"\nmodel_path = [prefix + f\"-0000{i}-of-00009.safetensors\" for i in range(1, 10)]\ninputs = {\n    \"latents\": torch.randn((1, 16, 128, 128), dtype=torch.bfloat16, device=\"cuda\"),\n    \"timestep\": torch.zeros((1,), dtype=torch.bfloat16, device=\"cuda\"),\n    \"prompt_emb\": torch.randn((1, 5, 3584), dtype=torch.bfloat16, device=\"cuda\"),\n    \"prompt_emb_mask\": torch.ones((1, 5), dtype=torch.int64, device=\"cuda\"),\n    \"height\": 1024,\n    \"width\": 1024,\n}\n\nmodel = load_model(QwenImageDiT, model_path, torch_dtype=torch.bfloat16, device=\"cuda\")\nwith torch.no_grad():\n    output = model(**inputs)\n```\n\n## 编写细粒度显存管理方案\n\n为了编写细粒度的显存管理方案，我们需用 `print(model)` 观察和分析模型结构：\n\n```\nQwenImageDiT(\n  (pos_embed): QwenEmbedRope()\n  (time_text_embed): TimestepEmbeddings(\n    (time_proj): TemporalTimesteps()\n    (timestep_embedder): DiffusersCompatibleTimestepProj(\n      (linear_1): Linear(in_features=256, out_features=3072, bias=True)\n      (act): SiLU()\n      (linear_2): Linear(in_features=3072, out_features=3072, bias=True)\n    )\n  )\n  (txt_norm): RMSNorm()\n  (img_in): Linear(in_features=64, out_features=3072, bias=True)\n  (txt_in): Linear(in_features=3584, out_features=3072, bias=True)\n  (transformer_blocks): ModuleList(\n    (0-59): 60 x QwenImageTransformerBlock(\n      (img_mod): 
Sequential(\n        (0): SiLU()\n        (1): Linear(in_features=3072, out_features=18432, bias=True)\n      )\n      (img_norm1): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)\n      (attn): QwenDoubleStreamAttention(\n        (to_q): Linear(in_features=3072, out_features=3072, bias=True)\n        (to_k): Linear(in_features=3072, out_features=3072, bias=True)\n        (to_v): Linear(in_features=3072, out_features=3072, bias=True)\n        (norm_q): RMSNorm()\n        (norm_k): RMSNorm()\n        (add_q_proj): Linear(in_features=3072, out_features=3072, bias=True)\n        (add_k_proj): Linear(in_features=3072, out_features=3072, bias=True)\n        (add_v_proj): Linear(in_features=3072, out_features=3072, bias=True)\n        (norm_added_q): RMSNorm()\n        (norm_added_k): RMSNorm()\n        (to_out): Sequential(\n          (0): Linear(in_features=3072, out_features=3072, bias=True)\n        )\n        (to_add_out): Linear(in_features=3072, out_features=3072, bias=True)\n      )\n      (img_norm2): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)\n      (img_mlp): QwenFeedForward(\n        (net): ModuleList(\n          (0): ApproximateGELU(\n            (proj): Linear(in_features=3072, out_features=12288, bias=True)\n          )\n          (1): Dropout(p=0.0, inplace=False)\n          (2): Linear(in_features=12288, out_features=3072, bias=True)\n        )\n      )\n      (txt_mod): Sequential(\n        (0): SiLU()\n        (1): Linear(in_features=3072, out_features=18432, bias=True)\n      )\n      (txt_norm1): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)\n      (txt_norm2): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)\n      (txt_mlp): QwenFeedForward(\n        (net): ModuleList(\n          (0): ApproximateGELU(\n            (proj): Linear(in_features=3072, out_features=12288, bias=True)\n          )\n          (1): Dropout(p=0.0, inplace=False)\n          (2): Linear(in_features=12288, out_features=3072, bias=True)\n     
   )\n      )\n    )\n  )\n  (norm_out): AdaLayerNorm(\n    (linear): Linear(in_features=3072, out_features=6144, bias=True)\n    (norm): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)\n  )\n  (proj_out): Linear(in_features=3072, out_features=64, bias=True)\n)\n```\n\n在显存管理中，我们只关心包含参数的 Layer。在这个模型结构中，`QwenEmbedRope`、`TemporalTimesteps`、`SiLU` 等 Layer 都是不包含参数的，`LayerNorm` 也因为设置了 `elementwise_affine=False` 不包含参数。包含参数的 Layer 只有 `Linear` 和 `RMSNorm`。\n\n`diffsynth.core.vram` 中提供了两个用于替换的模块用于显存管理：\n* `AutoWrappedLinear`: 用于替换 `Linear` 层\n* `AutoWrappedModule`: 用于替换其他任意层\n\n编写一个 `module_map`，将模型中的 `Linear` 和 `RMSNorm` 映射到对应的模块上：\n\n```python\nmodule_map={\n    torch.nn.Linear: AutoWrappedLinear,\n    RMSNorm: AutoWrappedModule,\n}\n```\n\n此外，还需要提供 `vram_config` 与 `vram_limit`，这两个参数在[显存管理](../Pipeline_Usage/VRAM_management.md#更多使用方式)中已有介绍。\n\n调用 `enable_vram_management` 即可启用显存管理，注意此时模型加载时的 `device` 为 `cpu`，与 `offload_device` 一致：\n\n```python\nfrom diffsynth.core import load_model, enable_vram_management, AutoWrappedLinear, AutoWrappedModule\nfrom diffsynth.models.qwen_image_dit import QwenImageDiT, RMSNorm\nimport torch\n\nprefix = \"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model\"\nmodel_path = [prefix + f\"-0000{i}-of-00009.safetensors\" for i in range(1, 10)]\ninputs = {\n    \"latents\": torch.randn((1, 16, 128, 128), dtype=torch.bfloat16, device=\"cuda\"),\n    \"timestep\": torch.zeros((1,), dtype=torch.bfloat16, device=\"cuda\"),\n    \"prompt_emb\": torch.randn((1, 5, 3584), dtype=torch.bfloat16, device=\"cuda\"),\n    \"prompt_emb_mask\": torch.ones((1, 5), dtype=torch.int64, device=\"cuda\"),\n    \"height\": 1024,\n    \"width\": 1024,\n}\n\nmodel = load_model(QwenImageDiT, model_path, torch_dtype=torch.bfloat16, device=\"cpu\")\nenable_vram_management(\n    model,\n    module_map={\n        torch.nn.Linear: AutoWrappedLinear,\n        RMSNorm: AutoWrappedModule,\n    },\n    vram_config = {\n        \"offload_dtype\": torch.bfloat16,\n      
  \"offload_device\": \"cpu\",\n        \"onload_dtype\": torch.bfloat16,\n        \"onload_device\": \"cpu\",\n        \"preparing_dtype\": torch.bfloat16,\n        \"preparing_device\": \"cuda\",\n        \"computation_dtype\": torch.bfloat16,\n        \"computation_device\": \"cuda\",\n    },\n    vram_limit=0,\n)\nwith torch.no_grad():\n    output = model(**inputs)\n```\n\n以上代码只需要 2G 显存就可以运行 20B 模型的 `forward`。\n\n## Disk Offload\n\n[Disk Offload](../Pipeline_Usage/VRAM_management.md#disk-offload) 是特殊的显存管理方案，需在模型加载过程中启用，而非模型加载完毕后。通常，在以上代码能够顺利运行的前提下，Disk Offload 可以直接启用：\n\n```python\nfrom diffsynth.core import load_model, enable_vram_management, AutoWrappedLinear, AutoWrappedModule\nfrom diffsynth.models.qwen_image_dit import QwenImageDiT, RMSNorm\nimport torch\n\nprefix = \"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model\"\nmodel_path = [prefix + f\"-0000{i}-of-00009.safetensors\" for i in range(1, 10)]\ninputs = {\n    \"latents\": torch.randn((1, 16, 128, 128), dtype=torch.bfloat16, device=\"cuda\"),\n    \"timestep\": torch.zeros((1,), dtype=torch.bfloat16, device=\"cuda\"),\n    \"prompt_emb\": torch.randn((1, 5, 3584), dtype=torch.bfloat16, device=\"cuda\"),\n    \"prompt_emb_mask\": torch.ones((1, 5), dtype=torch.int64, device=\"cuda\"),\n    \"height\": 1024,\n    \"width\": 1024,\n}\n\nmodel = load_model(\n    QwenImageDiT,\n    model_path,\n    module_map={\n        torch.nn.Linear: AutoWrappedLinear,\n        RMSNorm: AutoWrappedModule,\n    },\n    vram_config={\n        \"offload_dtype\": \"disk\",\n        \"offload_device\": \"disk\",\n        \"onload_dtype\": \"disk\",\n        \"onload_device\": \"disk\",\n        \"preparing_dtype\": torch.bfloat16,\n        \"preparing_device\": \"cuda\",\n        \"computation_dtype\": torch.bfloat16,\n        \"computation_device\": \"cuda\",\n    },\n    vram_limit=0,\n)\nwith torch.no_grad():\n    output = model(**inputs)\n```\n\nDisk Offload 是极为特殊的显存管理方案，只支持 `.safetensors` 格式文件，不支持 
`.bin`、`.pth`、`.ckpt` 等其他格式的文件，也不支持包含 Tensor reshape 操作的 [state dict converter](../Developer_Guide/Integrating_Your_Model.md#step-2-模型文件格式转换)。\n\n如果某个模型在常规显存管理下能够正常运行，但启用 Disk Offload 后无法正常运行，请在 GitHub 上给我们提 issue。\n\n## 写入默认配置\n\n为了让用户能够更方便地使用显存管理功能，我们将细粒度显存管理的配置写在 `diffsynth/configs/vram_management_module_maps.py` 中，上述模型的配置信息为：\n\n```python\n\"diffsynth.models.qwen_image_dit.QwenImageDiT\": {\n    \"diffsynth.models.qwen_image_dit.RMSNorm\": \"diffsynth.core.vram.layers.AutoWrappedModule\",\n    \"torch.nn.Linear\": \"diffsynth.core.vram.layers.AutoWrappedLinear\",\n}\n```\n"
  },
  {
    "path": "docs/zh/Developer_Guide/Integrating_Your_Model.md",
    "content": "# 接入模型结构\n\n本文档介绍如何将模型接入到 `DiffSynth-Studio` 框架中，供 `Pipeline` 等模块调用。\n\n## Step 1: 集成模型结构代码\n\n`DiffSynth-Studio` 中的所有模型结构实现统一在 `diffsynth/models` 中，每个 `.py` 代码文件分别实现一个模型结构，所有模型通过 `diffsynth/models/model_loader.py` 中的 `ModelPool` 来加载。在接入新的模型结构时，请在这个路径下建立新的 `.py` 文件。\n\n```shell\ndiffsynth/models/\n├── general_modules.py\n├── model_loader.py\n├── qwen_image_controlnet.py\n├── qwen_image_dit.py\n├── qwen_image_text_encoder.py\n├── qwen_image_vae.py\n└── ...\n```\n\n在大多数情况下，我们建议用 `PyTorch` 原生代码的形式集成模型，让模型结构类直接继承 `torch.nn.Module`，例如：\n\n```python\nimport torch\n\nclass NewDiffSynthModel(torch.nn.Module):\n    def __init__(self, dim=1024):\n        super().__init__()\n        self.linear = torch.nn.Linear(dim, dim)\n        self.activation = torch.nn.Sigmoid()\n    \n    def forward(self, x):\n        x = self.linear(x)\n        x = self.activation(x)\n        return x\n```\n\n如果模型结构的实现中包含额外的依赖，我们强烈建议将其删除，否则这会导致沉重的包依赖问题。在我们现有的模型中，Qwen-Image 的 Blockwise ControlNet 是以这种方式集成的，其代码很轻量，请参考 `diffsynth/models/qwen_image_controlnet.py`。\n\n如果模型已被 Huggingface Library （[`transformers`](https://huggingface.co/docs/transformers/main/index)、[`diffusers`](https://huggingface.co/docs/diffusers/main/index) 等）集成，我们能够以更简单的方式集成模型：\n\n<details>\n<summary>集成 Huggingface Library 风格模型结构代码</summary>\n\n这类模型在 Huggingface Library 中的加载方式为：\n\n```python\nfrom transformers import XXX_Model\n\nmodel = XXX_Model.from_pretrained(\"path_to_your_model\")\n```\n\n`DiffSynth-Studio` 不支持通过 `from_pretrained` 加载模型，因为这与显存管理等功能是冲突的，请将模型结构改写成以下格式：\n\n```python\nimport torch\n\nclass DiffSynth_XXX_Model(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        from transformers import XXX_Config, XXX_Model\n        config = XXX_Config(**{\n            \"architectures\": [\"XXX_Model\"],\n            \"other_configs\": \"Please copy and paste the other configs here.\",\n        })\n        self.model = XXX_Model(config)\n        \n    def forward(self, x):\n        outputs = 
self.model(x)\n        return outputs\n```\n\n其中 `XXX_Config` 为模型对应的 Config 类，例如 `Qwen2_5_VLModel` 的 Config 类为 `Qwen2_5_VLConfig`，可通过查阅其源代码找到。Config 内部的内容通常可以在模型库中的 `config.json` 中找到，`DiffSynth-Studio` 不会读取 `config.json` 文件，因此需要将其中的内容复制粘贴到代码中。\n\n在少数情况下，`transformers` 和 `diffusers` 的版本更新会导致部分模型无法导入，因此如果可能的话，我们仍建议使用前文所述的原生 PyTorch 模型集成方式。\n\n在我们现有的模型中，Qwen-Image 的 Text Encoder 是以这种方式集成的，其代码很轻量，请参考 `diffsynth/models/qwen_image_text_encoder.py`。\n\n</details>\n\n## Step 2: 模型文件格式转换\n\n由于开源社区中开发者提供的模型文件格式多种多样，因此我们有时需对模型文件格式进行转换，从而形成格式正确的 [state dict](https://docs.pytorch.org/tutorials/recipes/recipes/what_is_state_dict.html)，常见于以下几种情况：\n\n* 模型文件由不同代码库构建，例如 [Wan-AI/Wan2.1-T2V-1.3B](https://www.modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B) 和 [Wan-AI/Wan2.1-T2V-1.3B-Diffusers](https://www.modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B-Diffusers)。\n* 模型在接入中做了修改，例如 [Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image) 的 Text Encoder 在 `diffsynth/models/qwen_image_text_encoder.py` 中增加了 `model.` 前缀。\n* 模型文件包含多个模型，例如 [Wan-AI/Wan2.1-VACE-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B) 的 VACE Adapter 和基础 DiT 模型混合存储在同一组模型文件中。\n\n在我们的开发理念中，我们希望尽可能尊重模型原作者的意愿。如果对模型文件进行重新封装，例如 [Comfy-Org/Qwen-Image_ComfyUI](https://www.modelscope.cn/models/Comfy-Org/Qwen-Image_ComfyUI)，虽然我们可以更方便地调用模型，但流量（模型页面浏览量和下载量等）会被引向他处，模型的原作者也会失去删除模型的权力。因此，我们在框架中增加了 `diffsynth/utils/state_dict_converters` 这一模块，用于在模型加载过程中进行文件格式转换。\n\n这部分逻辑非常简单，以 Qwen-Image 的 Text Encoder 为例，只需要 10 行代码即可：\n\n```python\ndef QwenImageTextEncoderStateDictConverter(state_dict):\n    state_dict_ = {}\n    for k in state_dict:\n        v = state_dict[k]\n        if k.startswith(\"visual.\"):\n            k = \"model.\" + k\n        elif k.startswith(\"model.\"):\n            k = k.replace(\"model.\", \"model.language_model.\")\n        state_dict_[k] = v\n    return state_dict_\n```\n\n## Step 3: 编写模型 Config\n\n模型 Config 位于 `diffsynth/configs/model_configs.py`，用于识别模型类型并进行加载。需填入以下字段：\n\n* `model_hash`：模型文件哈希值，可通过 
`hash_model_file` 函数获取，此哈希值仅与模型文件中 state dict 的 keys 和张量 shape 有关，与文件中的其他信息无关。\n* `model_name`: 模型名称，用于给 `Pipeline` 识别所需模型。如果不同结构的模型在 `Pipeline` 中发挥的作用相同，则可以使用相同的 `model_name`。在接入新模型时，只需保证 `model_name` 与现有的其他功能模型不同即可。在 `Pipeline` 的 `from_pretrained` 中通过 `model_name` 获取对应的模型。\n* `model_class`: 模型结构导入路径，指向在 Step 1 中实现的模型结构类，例如 `diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder`。\n* `state_dict_converter`: 可选参数，如需进行模型文件格式转换，则需填入模型转换逻辑的导入路径，例如 `diffsynth.utils.state_dict_converters.qwen_image_text_encoder.QwenImageTextEncoderStateDictConverter`。\n* `extra_kwargs`: 可选参数，如果模型初始化时需传入额外参数，则需要填入这些参数，例如模型 [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny) 与 [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint) 都采用了 `diffsynth/models/qwen_image_controlnet.py` 中的 `QwenImageBlockWiseControlNet` 结构，但后者还需额外的配置 `additional_in_dim=4`，因此这部分配置信息需填入 `extra_kwargs` 字段。\n\n我们提供了一份代码，以便快速理解模型是如何通过这些配置信息加载的：\n\n```python\nfrom diffsynth.core import hash_model_file, load_state_dict, skip_model_initialization\nfrom diffsynth.models.qwen_image_text_encoder import QwenImageTextEncoder\nfrom diffsynth.utils.state_dict_converters.qwen_image_text_encoder import QwenImageTextEncoderStateDictConverter\nimport torch\n\nmodel_hash = \"8004730443f55db63092006dd9f7110e\"\nmodel_name = \"qwen_image_text_encoder\"\nmodel_class = QwenImageTextEncoder\nstate_dict_converter = QwenImageTextEncoderStateDictConverter\nextra_kwargs = {}\n\nmodel_path = [\n    \"models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors\",\n    \"models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors\",\n    \"models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors\",\n    \"models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors\",\n]\nif hash_model_file(model_path) == 
model_hash:\n    with skip_model_initialization():\n        model = model_class(**extra_kwargs)\n    state_dict = load_state_dict(model_path, torch_dtype=torch.bfloat16, device=\"cuda\")\n    state_dict = state_dict_converter(state_dict)\n    model.load_state_dict(state_dict, assign=True)\n    print(\"Done!\")\n```\n\n> Q: 上述代码的逻辑看起来很简单，为什么 `DiffSynth-Studio` 中的这部分代码极为复杂？\n> \n> A: 因为我们提供了激进的显存管理功能，这一功能与模型加载逻辑紧密耦合，导致框架结构较为复杂。我们已尽可能简化暴露给开发者的接口。\n\n`diffsynth/configs/model_configs.py` 中的 `model_hash` 并不唯一：同一模型文件中可能包含多个模型，因此多个模型 Config 可以具有相同的 `model_hash`。对于这种情况，请使用多个模型 Config 分别加载每个模型，并编写相应的 `state_dict_converter` 分离出每个模型所需的参数。\n\n## Step 4: 检验模型是否能被识别和加载\n\n模型接入之后，可通过以下代码验证模型是否能够被正确识别和加载，以下代码会尝试将模型加载到内存中：\n\n```python\nfrom diffsynth.models.model_loader import ModelPool\n\nmodel_pool = ModelPool()\nmodel_pool.auto_load_model(\n    [\n        \"models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors\",\n        \"models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors\",\n        \"models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors\",\n        \"models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors\",\n    ],\n)\n```\n\n如果模型能够被识别和加载，则会看到以下输出内容：\n\n```\nLoading models from: [\n    \"models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors\",\n    \"models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors\",\n    \"models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors\",\n    \"models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors\"\n]\nLoaded model: {\n    \"model_name\": \"qwen_image_text_encoder\",\n    \"model_class\": \"diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder\",\n    \"extra_kwargs\": null\n}\n```\n\n## Step 5: 编写模型显存管理方案\n\n`DiffSynth-Studio` 支持复杂的显存管理，详见[启用显存管理](../Developer_Guide/Enabling_VRAM_management.md)。\n"
  },
  {
    "path": "docs/zh/Developer_Guide/Training_Diffusion_Models.md",
    "content": "# 接入模型训练\n\n在[接入模型](../Developer_Guide/Integrating_Your_Model.md)并[实现 Pipeline](../Developer_Guide/Building_a_Pipeline.md)后，接下来接入模型训练功能。\n\n## 训推一致的 Pipeline 改造\n\n为了保证训练和推理过程严格的一致性，我们会在训练过程中沿用大部分推理代码，但仍需作出少量改造。\n\n首先，在推理过程中添加额外的逻辑，让图生图/视频生视频逻辑根据 `scheduler` 状态进行切换。以 Qwen-Image 为例：\n\n```python\nclass QwenImageUnit_InputImageEmbedder(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            input_params=(\"input_image\", \"noise\", \"tiled\", \"tile_size\", \"tile_stride\"),\n            output_params=(\"latents\", \"input_latents\"),\n            onload_model_names=(\"vae\",)\n        )\n\n    def process(self, pipe: QwenImagePipeline, input_image, noise, tiled, tile_size, tile_stride):\n        if input_image is None:\n            return {\"latents\": noise, \"input_latents\": None}\n        pipe.load_models_to_device(['vae'])\n        image = pipe.preprocess_image(input_image).to(device=pipe.device, dtype=pipe.torch_dtype)\n        input_latents = pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)\n        if pipe.scheduler.training:\n            return {\"latents\": noise, \"input_latents\": input_latents}\n        else:\n            latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])\n            return {\"latents\": latents, \"input_latents\": input_latents}\n```\n\n然后，在 `model_fn` 中启用 Gradient Checkpointing，这将以计算速度为代价，大幅度减少训练所需的显存。这并不是必需的，但我们强烈建议这么做。\n\n以 Qwen-Image 为例，修改前：\n\n```python\ntext, image = block(\n    image=image,\n    text=text,\n    temb=conditioning,\n    image_rotary_emb=image_rotary_emb,\n    attention_mask=attention_mask,\n)\n```\n\n修改后：\n\n```python\nfrom ..core import gradient_checkpoint_forward\n\ntext, image = gradient_checkpoint_forward(\n    block,\n    use_gradient_checkpointing,\n    use_gradient_checkpointing_offload,\n    image=image,\n    text=text,\n    temb=conditioning,\n    image_rotary_emb=image_rotary_emb,\n    
attention_mask=attention_mask,\n)\n```\n\n## 编写训练脚本\n\n`DiffSynth-Studio` 没有对训练框架做严格的封装，而是将脚本内容暴露给开发者，这种方式可以更方便地对训练脚本进行修改，实现额外的功能。开发者可参考现有的训练脚本，例如 `examples/qwen_image/model_training/train.py` 进行修改，从而适配新的模型训练。\n"
  },
  {
    "path": "docs/zh/Makefile",
    "content": "# Minimal makefile for Sphinx documentation\n#\n\n# You can set these variables from the command line, and also\n# from the environment for the first two.\nSPHINXOPTS    ?=\nSPHINXBUILD   ?= sphinx-build\nSOURCEDIR     = .\nBUILDDIR      = _build\n\n# Put it first so that \"make\" without argument is like \"make help\".\nhelp:\n\t@$(SPHINXBUILD) -M help \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)\n\n.PHONY: help Makefile\n\n# Catch-all target: route all unknown targets to Sphinx using the new\n# \"make mode\" option.  $(O) is meant as a shortcut for $(SPHINXOPTS).\n%: Makefile\n\t@$(SPHINXBUILD) -M $@ \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)"
  },
  {
    "path": "docs/zh/Model_Details/Anima.md",
    "content": "# Anima\n\nAnima 是由 CircleStone Labs 与 Comfy Org 训练并开源的图像生成模型。\n\n## 安装\n\n在使用本项目进行模型推理和训练前，请先安装 DiffSynth-Studio。\n\n```shell\ngit clone https://github.com/modelscope/DiffSynth-Studio.git\ncd DiffSynth-Studio\npip install -e .\n```\n\n更多关于安装的信息，请参考[安装依赖](../Pipeline_Usage/Setup.md)。\n\n## 快速开始\n\n运行以下代码可以快速加载 [circlestone-labs/Anima](https://www.modelscope.cn/models/circlestone-labs/Anima) 模型并进行推理。显存管理已启动，框架会自动根据剩余显存控制模型参数的加载，最低 8G 显存即可运行。\n\n```python\nfrom diffsynth.pipelines.anima_image import AnimaImagePipeline, ModelConfig\nimport torch\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": \"disk\",\n    \"onload_device\": \"disk\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = AnimaImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"circlestone-labs/Anima\", origin_file_pattern=\"split_files/diffusion_models/anima-preview.safetensors\", **vram_config),\n        ModelConfig(model_id=\"circlestone-labs/Anima\", origin_file_pattern=\"split_files/text_encoders/qwen_3_06b_base.safetensors\", **vram_config),\n        ModelConfig(model_id=\"circlestone-labs/Anima\", origin_file_pattern=\"split_files/vae/qwen_image_vae.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen3-0.6B\", origin_file_pattern=\"./\"),\n    tokenizer_t5xxl_config=ModelConfig(model_id=\"stabilityai/stable-diffusion-3.5-large\", origin_file_pattern=\"tokenizer_3/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\nprompt = \"Masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait.\"\nnegative_prompt = \"worst quality, low 
quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,\"\nimage = pipe(prompt, seed=0, num_inference_steps=50)\nimage.save(\"image.jpg\")\n```\n\n## 模型总览\n\n|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|\n|-|-|-|-|-|-|-|\n|[circlestone-labs/Anima](https://www.modelscope.cn/models/circlestone-labs/Anima)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_inference/anima-preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_inference_low_vram/anima-preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/full/anima-preview.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/validate_full/anima-preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/lora/anima-preview.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/validate_lora/anima-preview.py)|\n\n特殊训练脚本：\n\n* 差分 LoRA 训练：[doc](../Training/Differential_LoRA.md)\n* FP8 精度训练：[doc](../Training/FP8_Precision.md)\n* 两阶段拆分训练：[doc](../Training/Split_Training.md)\n* 端到端直接蒸馏：[doc](../Training/Direct_Distill.md)\n\n## 模型推理\n\n模型通过 `AnimaImagePipeline.from_pretrained` 加载，详见[加载模型](../Pipeline_Usage/Model_Inference.md#加载模型)。\n\n`AnimaImagePipeline` 推理的输入参数包括：\n\n* `prompt`: 提示词，描述画面中出现的内容。\n* `negative_prompt`: 负向提示词，描述画面中不应该出现的内容，默认值为 `\"\"`。\n* `cfg_scale`: Classifier-free guidance 的参数，默认值为 4.0。\n* `input_image`: 输入图像，用于图像到图像的生成。默认为 `None`。\n* `denoising_strength`: 去噪强度，控制生成图像与输入图像的相似度，默认值为 1.0。\n* `height`: 图像高度，需保证高度为 16 的倍数，默认值为 1024。\n* `width`: 图像宽度，需保证宽度为 16 的倍数，默认值为 1024。\n* `seed`: 随机种子。默认为 `None`，即完全随机。\n* `rand_device`: 生成随机高斯噪声矩阵的计算设备，默认为 `\"cpu\"`。当设置为 `\"cuda\"` 时，相同的随机种子在不同 GPU 上可能产生不同的生成结果。\n* `num_inference_steps`: 推理步数，默认值为 30。\n* `sigma_shift`: 调度器的 sigma 偏移量，默认为 `None`。\n* `progress_bar_cmd`: 进度条，默认为 `tqdm.tqdm`。可通过设置为 `lambda x:x` 
来屏蔽进度条。\n\n如果显存不足，请开启[显存管理](../Pipeline_Usage/VRAM_management.md)，我们在示例代码中提供了每个模型推荐的低显存配置，详见前文\"模型总览\"中的表格。\n\n## 模型训练\n\nAnima 系列模型统一通过 [`examples/anima/model_training/train.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/train.py) 进行训练，脚本的参数包括：\n\n* 通用训练参数\n    * 数据集基础配置\n        * `--dataset_base_path`: 数据集的根目录。\n        * `--dataset_metadata_path`: 数据集的元数据文件路径。\n        * `--dataset_repeat`: 每个 epoch 中数据集重复的次数。\n        * `--dataset_num_workers`: 每个 DataLoader 的进程数量。\n        * `--data_file_keys`: 元数据中需要加载的字段名称，通常是图像或视频文件的路径，以 `,` 分隔。\n    * 模型加载配置\n        * `--model_paths`: 要加载的模型路径。JSON 格式。\n        * `--model_id_with_origin_paths`: 带原始路径的模型 ID，例如 `\"circlestone-labs/Anima:split_files/text_encoders/qwen_3_06b_base.safetensors\"`，多个模型以 `,` 分隔。\n        * `--extra_inputs`: 模型 Pipeline 所需的额外输入参数，例如训练 ControlNet 模型时需要额外参数 `controlnet_inputs`，以 `,` 分隔。\n        * `--fp8_models`: 以 FP8 格式加载的模型，格式与 `--model_paths` 或 `--model_id_with_origin_paths` 一致，目前仅支持参数不被梯度更新的模型（不需要梯度回传，或梯度仅更新其 LoRA）。\n    * 训练基础配置\n        * `--learning_rate`: 学习率。\n        * `--num_epochs`: 轮数（Epoch）。\n        * `--trainable_models`: 可训练的模型，例如 `dit`、`vae`、`text_encoder`。\n        * `--find_unused_parameters`: DDP 训练中是否存在未使用的参数，少数模型包含不参与梯度计算的冗余参数，需开启这一设置避免在多 GPU 训练中报错。\n        * `--weight_decay`: 权重衰减大小，详见 [torch.optim.AdamW](https://docs.pytorch.org/docs/stable/generated/torch.optim.AdamW.html)。\n        * `--task`: 训练任务，默认为 `sft`，部分模型支持更多训练模式，请参考每个特定模型的文档。\n    * 输出配置\n        * `--output_path`: 模型保存路径。\n        * `--remove_prefix_in_ckpt`: 在模型文件的 state dict 中移除前缀。\n        * `--save_steps`: 保存模型的训练步数间隔，若此参数留空，则每个 epoch 保存一次。\n    * LoRA 配置\n        * `--lora_base_model`: LoRA 添加到哪个模型上。\n        * `--lora_target_modules`: LoRA 添加到哪些层上。\n        * `--lora_rank`: LoRA 的秩（Rank）。\n        * `--lora_checkpoint`: LoRA 检查点的路径。如果提供此路径，LoRA 将从此检查点加载。\n        * `--preset_lora_path`: 预置 LoRA 检查点路径，如果提供此路径，这一 LoRA 将会以融入基础模型的形式加载。此参数用于 LoRA 差分训练。\n        * `--preset_lora_model`: 预置 LoRA 融入的模型，例如 
`dit`。\n    * 梯度配置\n        * `--use_gradient_checkpointing`: 是否启用 gradient checkpointing。\n        * `--use_gradient_checkpointing_offload`: 是否将 gradient checkpointing 卸载到内存中。\n        * `--gradient_accumulation_steps`: 梯度累积步数。\n    * 图像宽高配置（适用于图像生成模型和视频生成模型）\n        * `--height`: 图像或视频的高度。将 `height` 和 `width` 留空以启用动态分辨率。\n        * `--width`: 图像或视频的宽度。将 `height` 和 `width` 留空以启用动态分辨率。\n        * `--max_pixels`: 图像或视频帧的最大像素面积，当启用动态分辨率时，分辨率大于这个数值的图片都会被缩小，分辨率小于这个数值的图片保持不变。\n* Anima 专有参数\n    * `--tokenizer_path`: tokenizer 的路径，适用于文生图模型，留空则自动从远程下载。\n    * `--tokenizer_t5xxl_path`: T5-XXL tokenizer 的路径，适用于文生图模型，留空则自动从远程下载。\n\n我们构建了一个样例图像数据集，以方便您进行测试，通过以下命令可以下载这个数据集：\n\n```shell\nmodelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --local_dir ./data/diffsynth_example_dataset\n```\n\n我们为每个模型编写了推荐的训练脚本，请参考前文\"模型总览\"中的表格。关于如何编写模型训练脚本，请参考[模型训练](../Pipeline_Usage/Model_Training.md)；更多高阶训练算法，请参考[训练框架详解](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/zh/Training/)。\n"
  },
  {
    "path": "docs/zh/Model_Details/FLUX.md",
    "content": "# FLUX\n\n![Image](https://github.com/user-attachments/assets/c01258e2-f251-441a-aa1e-ebb22f02594d)\n\nFLUX 是由 Black Forest Labs 开发并开源的图像生成模型系列。\n\n## 安装\n\n在使用本项目进行模型推理和训练前，请先安装 DiffSynth-Studio。\n\n```shell\ngit clone https://github.com/modelscope/DiffSynth-Studio.git\ncd DiffSynth-Studio\npip install -e .\n```\n\n更多关于安装的信息，请参考[安装依赖](../Pipeline_Usage/Setup.md)。\n\n## 快速开始\n\n运行以下代码可以快速加载 [black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev) 模型并进行推理。显存管理已启动，框架会自动根据剩余显存控制模型参数的加载，最低 8G 显存即可运行。\n\n```python\nimport torch\nfrom diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig\n\nvram_config = {\n    \"offload_dtype\": torch.float8_e4m3fn,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.float8_e4m3fn,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e4m3fn,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = FluxImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"flux1-dev.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder/model.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder_2/*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"ae.safetensors\", **vram_config),\n    ],\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 1,\n)\nprompt = \"CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. 
The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her.\"\nimage = pipe(prompt=prompt, seed=0)\nimage.save(\"image.jpg\")\n```\n\n## 模型总览\n\n<details>\n\n<summary>模型血缘</summary>\n\n```mermaid\ngraph LR;\n    FLUX.1-Series-->black-forest-labs/FLUX.1-dev;\n    FLUX.1-Series-->black-forest-labs/FLUX.1-Krea-dev;\n    FLUX.1-Series-->black-forest-labs/FLUX.1-Kontext-dev;\n    black-forest-labs/FLUX.1-dev-->FLUX.1-dev-ControlNet-Series;\n    FLUX.1-dev-ControlNet-Series-->alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta;\n    FLUX.1-dev-ControlNet-Series-->InstantX/FLUX.1-dev-Controlnet-Union-alpha;\n    FLUX.1-dev-ControlNet-Series-->jasperai/Flux.1-dev-Controlnet-Upscaler;\n    black-forest-labs/FLUX.1-dev-->InstantX/FLUX.1-dev-IP-Adapter;\n    black-forest-labs/FLUX.1-dev-->ByteDance/InfiniteYou;\n    black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/Eligen;\n    black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev;\n    black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev;\n    black-forest-labs/FLUX.1-dev-->ostris/Flex.2-preview;\n    black-forest-labs/FLUX.1-dev-->stepfun-ai/Step1X-Edit;\n    Qwen/Qwen2.5-VL-7B-Instruct-->stepfun-ai/Step1X-Edit;\n    black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/Nexus-GenV2;\n    Qwen/Qwen2.5-VL-7B-Instruct-->DiffSynth-Studio/Nexus-GenV2;\n```\n\n</details>\n\n|模型 ID|额外参数|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 
训练后验证|\n|-|-|-|-|-|-|-|-|\n|[black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev.py)|\n|[black-forest-labs/FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-Krea-dev.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-Krea-dev.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-Krea-dev.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-Krea-dev.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-Krea-dev.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-Krea-dev.py)|\n|[black-forest-labs/FLUX.1-Kontext-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev)|`kontext_images`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-Kontext-dev.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_
vram/FLUX.1-Kontext-dev.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-Kontext-dev.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-Kontext-dev.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-Kontext-dev.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-Kontext-dev.py)|\n|[alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta](https://www.modelscope.cn/models/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta)|`controlnet_inputs`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|\n|[InstantX/FLUX.1-dev-Controlnet-Union-alpha](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-Controlnet-Union-alpha)|`controlnet_inputs`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](https://github.co
m/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Union-alpha.py)|\n|[jasperai/Flux.1-dev-Controlnet-Upscaler](https://www.modelscope.cn/models/jasperai/Flux.1-dev-Controlnet-Upscaler)|`controlnet_inputs`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-Controlnet-Upscaler.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Upscaler.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Upscaler.py)|\n|[InstantX/FLUX.1-dev-IP-Adapter](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-IP-Adapter)|`ipadapter_images`, 
`ipadapter_scale`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-IP-Adapter.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-IP-Adapter.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-IP-Adapter.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-IP-Adapter.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-IP-Adapter.py)|\n|[ByteDance/InfiniteYou](https://www.modelscope.cn/models/ByteDance/InfiniteYou)|`infinityou_id_image`, `infinityou_guidance`, `controlnet_inputs`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-InfiniteYou.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-InfiniteYou.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-InfiniteYou.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-InfiniteYou.py)|\n|[DiffSynth-Studio/Eligen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)|`eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, 
`eligen_enable_inpaint`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-EliGen.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py)|-|-|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-EliGen.py)|\n|[DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev](https://www.modelscope.cn/models/DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev)|`lora_encoder_inputs`, `lora_encoder_scale`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-LoRA-Encoder.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py)|-|-|\n|[DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py)|-|-|-|-|-|\n|[stepfun-ai/Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit)|`step1x_reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/Step1X-Edit.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/Step1X-Edit.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/Step1X-Edit.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_t
raining/validate_full/Step1X-Edit.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/Step1X-Edit.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/Step1X-Edit.py)|\n|[ostris/Flex.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview)|`flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, `flex_control_stop`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLEX.2-preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLEX.2-preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLEX.2-preview.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLEX.2-preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLEX.2-preview.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLEX.2-preview.py)|\n|[DiffSynth-Studio/Nexus-GenV2](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2)|`nexus_gen_reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/Nexus-Gen-Editing.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/Nexus-Gen.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/Nexus-Gen.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/Nexus-Gen.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validat
e_lora/Nexus-Gen.py)|\n\n特殊训练脚本：\n\n* 差分 LoRA 训练：[doc](../Training/Differential_LoRA.md)\n* FP8 精度训练：[doc](../Training/FP8_Precision.md)\n* 两阶段拆分训练：[doc](../Training/Split_Training.md)\n* 端到端直接蒸馏：[doc](../Training/Direct_Distill.md)\n\n## 模型推理\n\n模型通过 `FluxImagePipeline.from_pretrained` 加载，详见[加载模型](../Pipeline_Usage/Model_Inference.md#加载模型)。\n\n`FluxImagePipeline` 推理的输入参数包括：\n\n* `prompt`: 提示词，描述画面中出现的内容。\n* `negative_prompt`: 负向提示词，描述画面中不应该出现的内容，默认值为 `\"\"`。\n* `cfg_scale`: Classifier-free guidance 的参数，默认值为 1，当设置为大于 1 的值时启用 CFG。\n* `height`: 图像高度，需保证高度为 16 的倍数。\n* `width`: 图像宽度，需保证宽度为 16 的倍数。\n* `seed`: 随机种子。默认为 `None`，即完全随机。\n* `rand_device`: 生成随机高斯噪声矩阵的计算设备，默认为 `\"cpu\"`。当设置为 `cuda` 时，在不同 GPU 上会导致不同的生成结果。\n* `num_inference_steps`: 推理次数，默认值为 30。\n* `embedded_guidance`: 嵌入式引导参数，默认值为 3.5。\n* `t5_sequence_length`: T5 文本编码器的序列长度，默认为 512。\n* `tiled`: 是否启用 VAE 分块推理，默认为 `False`。设置为 `True` 时可显著减少 VAE 编解码阶段的显存占用，会产生少许误差，以及少量推理时间延长。\n* `tile_size`: VAE 编解码阶段的分块大小，默认为 128，仅在 `tiled=True` 时生效。\n* `tile_stride`: VAE 编解码阶段的分块步长，默认为 64，仅在 `tiled=True` 时生效，需保证其数值小于或等于 `tile_size`。\n* `progress_bar_cmd`: 进度条，默认为 `tqdm.tqdm`。可通过设置为 `lambda x:x` 来屏蔽进度条。\n* `controlnet_inputs`: ControlNet 模型的输入，类型为 `ControlNetInput` 列表。\n* `ipadapter_images`: IP-Adapter 模型的输入图像列表。\n* `ipadapter_scale`: IP-Adapter 模型的引导强度。\n* `infinityou_id_image`: InfiniteYou 模型的输入图像。\n* `infinityou_guidance`: InfiniteYou 模型的引导强度。\n* `kontext_images`: Kontext 模型的输入图像。\n* `eligen_entity_prompts`: EliGen 分区控制的提示词列表。\n* `eligen_entity_masks`: EliGen 分区控制的区域遮罩图像列表。\n* `eligen_enable_on_negative`: 是否在 CFG 的负向一侧启用 EliGen 分区控制。\n* `eligen_enable_inpaint`: 是否启用 EliGen 分区控制的局部重绘功能。\n* `lora_encoder_inputs`: LoRA 编码器的输入图像列表。\n* `lora_encoder_scale`: LoRA 编码器的引导强度。\n* `step1x_reference_image`: Step1X 模型的参考图像。\n* `flex_inpaint_image`: Flex 模型的待修复图像。\n* `flex_inpaint_mask`: Flex 模型的修复遮罩。\n* `flex_control_image`: Flex 模型的控制图像。\n* `flex_control_strength`: Flex 模型的控制强度。\n* `flex_control_stop`: Flex 模型的控制停止时间步。\n* 
`nexus_gen_reference_image`: Nexus-Gen 模型的参考图像。\n\n如果显存不足，请开启[显存管理](../Pipeline_Usage/VRAM_management.md)，我们在示例代码中提供了每个模型推荐的低显存配置，详见前文\"模型总览\"中的表格。\n\n## 模型训练\n\nFLUX 系列模型统一通过 [`examples/flux/model_training/train.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/train.py) 进行训练，脚本的参数包括：\n\n* 通用训练参数\n    * 数据集基础配置\n        * `--dataset_base_path`: 数据集的根目录。\n        * `--dataset_metadata_path`: 数据集的元数据文件路径。\n        * `--dataset_repeat`: 每个 epoch 中数据集重复的次数。\n        * `--dataset_num_workers`: 每个 DataLoader 的进程数量。\n        * `--data_file_keys`: 元数据中需要加载的字段名称，通常是图像或视频文件的路径，以 `,` 分隔。\n    * 模型加载配置\n        * `--model_paths`: 要加载的模型路径。JSON 格式。\n        * `--model_id_with_origin_paths`: 带原始路径的模型 ID，例如 `\"black-forest-labs/FLUX.1-dev:flux1-dev.safetensors\"`。用逗号分隔。\n        * `--extra_inputs`: 模型 Pipeline 所需的额外输入参数，例如训练 ControlNet 模型时需要额外参数 `controlnet_inputs`，以 `,` 分隔。\n        * `--fp8_models`：以 FP8 格式加载的模型，格式与 `--model_paths` 或 `--model_id_with_origin_paths` 一致，目前仅支持参数不被梯度更新的模型（不需要梯度回传，或梯度仅更新其 LoRA）。\n    * 训练基础配置\n        * `--learning_rate`: 学习率。\n        * `--num_epochs`: 轮数（Epoch）。\n        * `--trainable_models`: 可训练的模型，例如 `dit`、`vae`、`text_encoder`。\n        * `--find_unused_parameters`: DDP 训练中是否存在未使用的参数，少数模型包含不参与梯度计算的冗余参数，需开启这一设置避免在多 GPU 训练中报错。\n        * `--weight_decay`：权重衰减大小，详见 [torch.optim.AdamW](https://docs.pytorch.org/docs/stable/generated/torch.optim.AdamW.html)。\n        * `--task`: 训练任务，默认为 `sft`，部分模型支持更多训练模式，请参考每个特定模型的文档。\n    * 输出配置\n        * `--output_path`: 模型保存路径。\n        * `--remove_prefix_in_ckpt`: 在模型文件的 state dict 中移除前缀。\n        * `--save_steps`: 保存模型的训练步数间隔，若此参数留空，则每个 epoch 保存一次。\n    * LoRA 配置\n        * `--lora_base_model`: LoRA 添加到哪个模型上。\n        * `--lora_target_modules`: LoRA 添加到哪些层上。\n        * `--lora_rank`: LoRA 的秩（Rank）。\n        * `--lora_checkpoint`: LoRA 检查点的路径。如果提供此路径，LoRA 将从此检查点加载。\n        * `--preset_lora_path`: 预置 LoRA 检查点路径，如果提供此路径，这一 LoRA 将会以融入基础模型的形式加载。此参数用于 LoRA 差分训练。\n        * 
`--preset_lora_model`: 预置 LoRA 融入的模型，例如 `dit`。\n    * 梯度配置\n        * `--use_gradient_checkpointing`: 是否启用 gradient checkpointing。\n        * `--use_gradient_checkpointing_offload`: 是否将 gradient checkpointing 卸载到内存中。\n        * `--gradient_accumulation_steps`: 梯度累积步数。\n    * 图像宽高配置（适用于图像生成模型和视频生成模型）\n        * `--height`: 图像或视频的高度。将 `height` 和 `width` 留空以启用动态分辨率。\n        * `--width`: 图像或视频的宽度。将 `height` 和 `width` 留空以启用动态分辨率。\n        * `--max_pixels`: 图像或视频帧的最大像素面积，当启用动态分辨率时，分辨率大于这个数值的图片都会被缩小，分辨率小于这个数值的图片保持不变。\n* FLUX 专有参数\n    * `--tokenizer_1_path`: CLIP tokenizer 的路径，留空则自动从远程下载。\n    * `--tokenizer_2_path`: T5 tokenizer 的路径，留空则自动从远程下载。\n    * `--align_to_opensource_format`: 是否将 LoRA 格式对齐到开源格式，仅适用于 DiT 的 LoRA。\n\n我们构建了一个样例图像数据集，以方便您进行测试，通过以下命令可以下载这个数据集：\n\n```shell\nmodelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --local_dir ./data/diffsynth_example_dataset\n```\n\n我们为每个模型编写了推荐的训练脚本，请参考前文\"模型总览\"中的表格。关于如何编写模型训练脚本，请参考[模型训练](../Pipeline_Usage/Model_Training.md)；更多高阶训练算法，请参考[训练框架详解](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/zh/Training/)。\n"
  },
  {
    "path": "docs/zh/Model_Details/FLUX2.md",
    "content": "# FLUX.2\n\nFLUX.2 是由 Black Forest Labs 训练并开源的图像生成模型。\n\n## 模型血缘\n\n```mermaid\ngraph LR;\n    FLUX.2-Series-->black-forest-labs/FLUX.2-dev;\n    FLUX.2-Series-->black-forest-labs/FLUX.2-klein-4B;\n    FLUX.2-Series-->black-forest-labs/FLUX.2-klein-9B;\n```\n\n## 安装\n\n在使用本项目进行模型推理和训练前，请先安装 DiffSynth-Studio。\n\n```shell\ngit clone https://github.com/modelscope/DiffSynth-Studio.git\ncd DiffSynth-Studio\npip install -e .\n```\n\n更多关于安装的信息，请参考[安装依赖](../Pipeline_Usage/Setup.md)。\n\n## 快速开始\n\n运行以下代码可以快速加载 [black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev) 模型并进行推理。显存管理已启动，框架会自动根据剩余显存控制模型参数的加载，最低 10G 显存即可运行。\n\n```python\nfrom diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig\nimport torch\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.float8_e4m3fn,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e4m3fn,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = Flux2ImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-dev\", origin_file_pattern=\"text_encoder/*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-dev\", origin_file_pattern=\"transformer/*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-dev\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"black-forest-labs/FLUX.2-dev\", origin_file_pattern=\"tokenizer/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\nprompt = \"High resolution. A dreamy underwater portrait of a serene young woman in a flowing blue dress. 
Her hair floats softly around her face, strands delicately suspended in the water. Clear, shimmering light filters through, casting gentle highlights, while tiny bubbles rise around her. Her expression is calm, her features finely detailed—creating a tranquil, ethereal scene.\"\nimage = pipe(prompt, seed=42, rand_device=\"cuda\", num_inference_steps=50)\nimage.save(\"image.jpg\")\n```\n\n## 模型总览\n\n|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|\n|-|-|-|-|-|-|-|\n|[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference/FLUX.2-dev.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|-|-|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)|\n|[black-forest-labs/FLUX.2-klein-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference/FLUX.2-klein-4B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference_low_vram/FLUX.2-klein-4B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/full/FLUX.2-klein-4B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_full/FLUX.2-klein-4B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/lora/FLUX.2-klein-4B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_lora/FLUX.2-klein-4B.py)|\n|[black-forest-labs/FLUX.2-klein-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B)|[code](https://gith
ub.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference/FLUX.2-klein-9B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/full/FLUX.2-klein-9B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_full/FLUX.2-klein-9B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/lora/FLUX.2-klein-9B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_lora/FLUX.2-klein-9B.py)|\n|[black-forest-labs/FLUX.2-klein-base-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-4B)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference/FLUX.2-klein-base-4B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-4B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/full/FLUX.2-klein-base-4B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_full/FLUX.2-klein-base-4B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/lora/FLUX.2-klein-base-4B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-4B.py)|\n|[black-forest-labs/FLUX.2-klein-base-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-9B)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference/FLUX.2-klein-base-9B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-9B.py)|[code](https://github.com/modelscope/Dif
fSynth-Studio/blob/main/examples/flux2/model_training/full/FLUX.2-klein-base-9B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_full/FLUX.2-klein-base-9B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/lora/FLUX.2-klein-base-9B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-9B.py)|\n\n特殊训练脚本：\n\n* 差分 LoRA 训练：[doc](../Training/Differential_LoRA.md)\n* FP8 精度训练：[doc](../Training/FP8_Precision.md)\n* 两阶段拆分训练：[doc](../Training/Split_Training.md)\n* 端到端直接蒸馏：[doc](../Training/Direct_Distill.md)\n\n## 模型推理\n\n模型通过 `Flux2ImagePipeline.from_pretrained` 加载，详见[加载模型](../Pipeline_Usage/Model_Inference.md#加载模型)。\n\n`Flux2ImagePipeline` 推理的输入参数包括：\n\n* `prompt`: 提示词，描述画面中出现的内容。\n* `negative_prompt`: 负向提示词，描述画面中不应该出现的内容，默认值为 `\"\"`。\n* `cfg_scale`: Classifier-free guidance 的参数，默认值为 1，当设置为大于 1 的值时启用 CFG。\n* `height`: 图像高度，需保证高度为 16 的倍数。\n* `width`: 图像宽度，需保证宽度为 16 的倍数。\n* `seed`: 随机种子。默认为 `None`，即完全随机。\n* `rand_device`: 生成随机高斯噪声矩阵的计算设备，默认为 `\"cpu\"`。当设置为 `cuda` 时，在不同 GPU 上会导致不同的生成结果。\n* `num_inference_steps`: 推理次数，默认值为 30。\n* `embedded_guidance`: 嵌入式引导参数，默认值为 3.5。\n* `t5_sequence_length`: T5 文本编码器的序列长度，默认为 512。\n* `tiled`: 是否启用 VAE 分块推理，默认为 `False`。设置为 `True` 时可显著减少 VAE 编解码阶段的显存占用，会产生少许误差，以及少量推理时间延长。\n* `tile_size`: VAE 编解码阶段的分块大小，默认为 128，仅在 `tiled=True` 时生效。\n* `tile_stride`: VAE 编解码阶段的分块步长，默认为 64，仅在 `tiled=True` 时生效，需保证其数值小于或等于 `tile_size`。\n* `progress_bar_cmd`: 进度条，默认为 `tqdm.tqdm`。可通过设置为 `lambda x:x` 来屏蔽进度条。\n\n如果显存不足，请开启[显存管理](../Pipeline_Usage/VRAM_management.md)，我们在示例代码中提供了每个模型推荐的低显存配置，详见前文\"模型总览\"中的表格。\n\n## 模型训练\n\nFLUX.2 系列模型统一通过 [`examples/flux2/model_training/train.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/train.py) 进行训练，脚本的参数包括：\n\n* 通用训练参数\n    * 数据集基础配置\n        * `--dataset_base_path`: 数据集的根目录。\n        * 
`--dataset_metadata_path`: 数据集的元数据文件路径。\n        * `--dataset_repeat`: 每个 epoch 中数据集重复的次数。\n        * `--dataset_num_workers`: 每个 DataLoader 的进程数量。\n        * `--data_file_keys`: 元数据中需要加载的字段名称，通常是图像或视频文件的路径，以 `,` 分隔。\n    * 模型加载配置\n        * `--model_paths`: 要加载的模型路径。JSON 格式。\n        * `--model_id_with_origin_paths`: 带原始路径的模型 ID，例如 `\"black-forest-labs/FLUX.2-dev:text_encoder/*.safetensors\"`。用逗号分隔。\n        * `--extra_inputs`: 模型 Pipeline 所需的额外输入参数，例如训练 ControlNet 模型时需要额外参数 `controlnet_inputs`，以 `,` 分隔。\n        * `--fp8_models`：以 FP8 格式加载的模型，格式与 `--model_paths` 或 `--model_id_with_origin_paths` 一致，目前仅支持参数不被梯度更新的模型（不需要梯度回传，或梯度仅更新其 LoRA）。\n    * 训练基础配置\n        * `--learning_rate`: 学习率。\n        * `--num_epochs`: 轮数（Epoch）。\n        * `--trainable_models`: 可训练的模型，例如 `dit`、`vae`、`text_encoder`。\n        * `--find_unused_parameters`: DDP 训练中是否存在未使用的参数，少数模型包含不参与梯度计算的冗余参数，需开启这一设置避免在多 GPU 训练中报错。\n        * `--weight_decay`：权重衰减大小，详见 [torch.optim.AdamW](https://docs.pytorch.org/docs/stable/generated/torch.optim.AdamW.html)。\n        * `--task`: 训练任务，默认为 `sft`，部分模型支持更多训练模式，请参考每个特定模型的文档。\n    * 输出配置\n        * `--output_path`: 模型保存路径。\n        * `--remove_prefix_in_ckpt`: 在模型文件的 state dict 中移除前缀。\n        * `--save_steps`: 保存模型的训练步数间隔，若此参数留空，则每个 epoch 保存一次。\n    * LoRA 配置\n        * `--lora_base_model`: LoRA 添加到哪个模型上。\n        * `--lora_target_modules`: LoRA 添加到哪些层上。\n        * `--lora_rank`: LoRA 的秩（Rank）。\n        * `--lora_checkpoint`: LoRA 检查点的路径。如果提供此路径，LoRA 将从此检查点加载。\n        * `--preset_lora_path`: 预置 LoRA 检查点路径，如果提供此路径，这一 LoRA 将会以融入基础模型的形式加载。此参数用于 LoRA 差分训练。\n        * `--preset_lora_model`: 预置 LoRA 融入的模型，例如 `dit`。\n    * 梯度配置\n        * `--use_gradient_checkpointing`: 是否启用 gradient checkpointing。\n        * `--use_gradient_checkpointing_offload`: 是否将 gradient checkpointing 卸载到内存中。\n        * `--gradient_accumulation_steps`: 梯度累积步数。\n    * 图像宽高配置（适用于图像生成模型和视频生成模型）\n        * `--height`: 图像或视频的高度。将 `height` 和 `width` 留空以启用动态分辨率。\n        * `--width`: 图像或视频的宽度。将 
`height` 和 `width` 留空以启用动态分辨率。\n        * `--max_pixels`: 图像或视频帧的最大像素面积，当启用动态分辨率时，分辨率大于这个数值的图片都会被缩小，分辨率小于这个数值的图片保持不变。\n* FLUX.2 专有参数\n    * `--tokenizer_path`: tokenizer 的路径，适用于文生图模型，留空则自动从远程下载。\n\n我们构建了一个样例图像数据集，以方便您进行测试，通过以下命令可以下载这个数据集：\n\n```shell\nmodelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --local_dir ./data/diffsynth_example_dataset\n```\n\n我们为每个模型编写了推荐的训练脚本，请参考前文\"模型总览\"中的表格。关于如何编写模型训练脚本，请参考[模型训练](../Pipeline_Usage/Model_Training.md)；更多高阶训练算法，请参考[训练框架详解](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/zh/Training/)。\n"
  },
  {
    "path": "docs/zh/Model_Details/LTX-2.md",
    "content": "# LTX-2\n\nLTX-2 是由 Lightricks 开发的音视频生成模型系列。\n\n## 安装\n\n在使用本项目进行模型推理和训练前，请先安装 DiffSynth-Studio。\n\n```shell\ngit clone https://github.com/modelscope/DiffSynth-Studio.git\ncd DiffSynth-Studio\npip install -e .\n```\n\n更多关于安装的信息，请参考[安装依赖](../Pipeline_Usage/Setup.md)。\n\n## 快速开始\n\n运行以下代码可以快速加载 [Lightricks/LTX-2](https://www.modelscope.cn/models/Lightricks/LTX-2) 模型并进行推理。显存管理已启动，框架会自动根据剩余显存控制模型参数的加载，最低 8GB 显存即可运行。\n\n```python\nimport torch\nfrom diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig\nfrom diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2\n\nvram_config = {\n    \"offload_dtype\": torch.float8_e5m2,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.float8_e5m2,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e5m2,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\n\"\"\"\nOffical model repo: https://www.modelscope.cn/models/Lightricks/LTX-2\nRepackaged model repo: https://www.modelscope.cn/models/DiffSynth-Studio/LTX-2-Repackage\nFor base models of LTX-2, offical checkpoint (with model config ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-19b-dev.safetensors\"))\nand repackaged checkpoints (with model config ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"*.safetensors\")) are both supported.\nWe have repackeged the official checkpoints in DiffSynth-Studio/LTX-2-Repackage repo to support separate loading of different submodules,\nand avoid redundant memory usage when users only want to use part of the model.\n\"\"\"\n# use the repackaged modelconfig from \"DiffSynth-Studio/LTX-2-Repackage\" to avoid redundant model loading\npipe = LTX2AudioVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\", 
origin_file_pattern=\"model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"transformer.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"text_encoder_post_modules.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"video_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"audio_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"audio_vocoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"video_vae_encoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-spatial-upscaler-x2-1.0.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\"),\n    stage2_lora_config=ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-19b-distilled-lora-384.safetensors\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\n\n# use the following modelconfig if you want to initialize model from offical checkpoints from \"Lightricks/LTX-2\"\n# pipe = LTX2AudioVideoPipeline.from_pretrained(\n#     torch_dtype=torch.bfloat16,\n#     device=\"cuda\",\n#     model_configs=[\n#         ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\", origin_file_pattern=\"model-*.safetensors\", **vram_config),\n#         ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-19b-dev.safetensors\", **vram_config),\n#         ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-spatial-upscaler-x2-1.0.safetensors\", **vram_config),\n#     
],\n#     tokenizer_config=ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\"),\n#     stage2_lora_config=ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-19b-distilled-lora-384.safetensors\"),\n#     vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n# )\n\nprompt = \"A girl is very happy, she is speaking: \\\"I enjoy working with Diffsynth-Studio, it's a perfect framework.\\\"\"\nnegative_prompt = (\n    \"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, \"\n    \"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, \"\n    \"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, \"\n    \"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of \"\n    \"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent \"\n    \"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny \"\n    \"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, \"\n    \"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, \"\n    \"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward \"\n    \"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, \"\n    \"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts.\"\n)\nheight, width, num_frames = 512 * 2, 768 * 2, 121\nvideo, audio = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    seed=43,\n    height=height,\n    width=width,\n    num_frames=num_frames,\n    tiled=True,\n    use_two_stage_pipeline=True,\n)\nwrite_video_audio_ltx2(\n    video=video,\n    audio=audio,\n    
output_path='ltx2_twostage.mp4',\n    fps=24,\n    audio_sample_rate=24000,\n)\n```\n\n## 模型总览\n|模型 ID|额外参数|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|\n|-|-|-|-|-|-|-|-|\n|[Lightricks/LTX-2.3: OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`input_images`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2.3-I2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2.3-I2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/full/LTX-2.3-I2AV-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_full/LTX-2.3-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/lora/LTX-2.3-I2AV-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_lora/LTX-2.3-I2AV.py)|\n|[Lightricks/LTX-2.3: TwoStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`input_images`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2.3-I2AV-TwoStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2.3-I2AV-TwoStage.py)|-|-|-|-|\n|[Lightricks/LTX-2.3: DistilledPipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`input_images`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2.3-I2AV-DistilledPipeline.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2.3-I2AV-DistilledPipeline.py)|-|-|-|-|\n|[Lightricks/LTX-2.3: 
OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2.3-T2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/full/LTX-2.3-T2AV-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_full/LTX-2.3-T2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/lora/LTX-2.3-T2AV-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_lora/LTX-2.3-T2AV.py)|\n|[Lightricks/LTX-2.3: TwoStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2.3-T2AV-TwoStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-TwoStage.py)|-|-|-|-|\n|[Lightricks/LTX-2.3: DistilledPipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2.3-T2AV-DistilledPipeline.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-DistilledPipeline.py)|-|-|-|-|\n|[Lightricks/LTX-2.3: A2V](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`retake_audio`,`audio_sample_rate`,`retake_audio_regions`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2.3-A2V-TwoStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2.3-A2V-TwoStage.py)|-|-|-|-|\n|[Lightricks/LTX-2.3: 
Retake](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`retake_video`,`retake_video_regions`,`retake_audio`,`audio_sample_rate`,`retake_audio_regions`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2.3-T2AV-TwoStage-Retake.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-TwoStage-Retake.py)|-|-|-|-|\n|[Lightricks/LTX-2.3-22b-IC-LoRA-Union-Control](https://www.modelscope.cn/models/Lightricks/LTX-2.3-22b-IC-LoRA-Union-Control)|`in_context_videos`,`in_context_downsample_factor`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2.3-T2AV-IC-LoRA-Union-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-IC-LoRA-Union-Control.py)|-|-|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/lora/LTX-2.3-T2AV-IC-LoRA-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_lora/LTX-2.3-T2AV-IC-LoRA.py)|\n|[Lightricks/LTX-2.3-22b-IC-LoRA-Motion-Track-Control](https://www.modelscope.cn/models/Lightricks/LTX-2.3-22b-IC-LoRA-Motion-Track-Control)|`in_context_videos`,`in_context_downsample_factor`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2.3-T2AV-IC-LoRA-Motion-Track-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-IC-LoRA-Motion-Track-Control.py)|-|-|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/lora/LTX-2.3-T2AV-IC-LoRA-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_lora/LTX-2.3-T2AV-IC-LoRA.py)|\n|[Lightricks/LTX-2: 
OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/full/LTX-2-T2AV-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_full/LTX-2-T2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/lora/LTX-2-T2AV-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_lora/LTX-2-T2AV.py)|\n|[Lightricks/LTX-2-19b-IC-LoRA-Union-Control](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-IC-LoRA-Union-Control)|`in_context_videos`,`in_context_downsample_factor`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-IC-LoRA-Union-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-IC-LoRA-Union-Control.py)|-|-|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/lora/LTX-2-T2AV-IC-LoRA-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_lora/LTX-2-T2AV-IC-LoRA.py)|\n|[Lightricks/LTX-2-19b-IC-LoRA-Detailer](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-IC-LoRA-Detailer)|`in_context_videos`,`in_context_downsample_factor`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-IC-LoRA-Detailer.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-IC-LoRA-Detailer.py)|-|-|[code](https://github.com/modelscope/DiffSynth-Studio
/blob/main/examples/ltx2/model_training/lora/LTX-2-T2AV-IC-LoRA-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_lora/LTX-2-T2AV-IC-LoRA.py)|\n|[Lightricks/LTX-2: TwoStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-TwoStage.py)|-|-|-|-|\n|[Lightricks/LTX-2: DistilledPipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-DistilledPipeline.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-DistilledPipeline.py)|-|-|-|-|\n|[Lightricks/LTX-2: OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-I2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-OneStage.py)|-|-|-|-|\n|[Lightricks/LTX-2: TwoStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-I2AV-TwoStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-TwoStage.py)|-|-|-|-|\n|[Lightricks/LTX-2: 
DistilledPipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-I2AV-DistilledPipeline.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-DistilledPipeline.py)|-|-|-|-|\n|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-In](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-In)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-In.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-In.py)|-|-|-|-|\n|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Out](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Out)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Out.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Out.py)|-|-|-|-|\n|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Left](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Left)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Left.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Left.py)|-|-|-|-|\n|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Right](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Right)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Right.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx
2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Right.py)|-|-|-|-|\n|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Up](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Up)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Jib-Up.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Jib-Up.py)|-|-|-|-|\n|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Down](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Down)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Jib-Down.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Jib-Down.py)|-|-|-|-|\n|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Static](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Static)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Static.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Static.py)|-|-|-|-|\n\n## 模型推理\n\n模型通过 `LTX2AudioVideoPipeline.from_pretrained` 加载，详见[加载模型](../Pipeline_Usage/Model_Inference.md#加载模型)。\n\n`LTX2AudioVideoPipeline` 推理的输入参数包括：\n\n* `prompt`: 提示词，描述视频中出现的内容。\n* `negative_prompt`: 负向提示词，描述视频中不应该出现的内容，默认值为 `\"\"`。\n* `cfg_scale`: Classifier-free guidance 的参数，默认值为 3.0。\n* `input_images`: 输入图像列表，用于图生视频。\n* `input_images_indexes`: 输入图像在视频中的帧索引列表。\n* `input_images_strength`: 输入图像的强度，默认值为 1.0。\n* `denoising_strength`: 去噪强度，范围是 0～1，默认值为 1.0。\n* `seed`: 随机种子。默认为 `None`，即完全随机。\n* `rand_device`: 生成随机高斯噪声矩阵的计算设备，默认为 `\"cpu\"`。当设置为 `cuda` 时，在不同 GPU 上会导致不同的生成结果。\n* `height`: 视频高度，需保证高度为 32 的倍数（单阶段）或 64 的倍数（两阶段）。\n* `width`: 视频宽度，需保证宽度为 32 
的倍数（单阶段）或 64 的倍数（两阶段）。\n* `num_frames`: 视频帧数，默认值为 121，需保证为 8 的倍数 + 1。\n* `num_inference_steps`: 推理步数，默认值为 40。\n* `tiled`: 是否启用 VAE 分块推理，默认为 `True`。设置为 `True` 时可显著减少 VAE 编解码阶段的显存占用，但会引入少许误差，并略微延长推理时间。\n* `tile_size_in_pixels`: VAE 编解码阶段的像素分块大小，默认为 512。\n* `tile_overlap_in_pixels`: VAE 编解码阶段的像素分块重叠大小，默认为 128。\n* `tile_size_in_frames`: VAE 编解码阶段的帧分块大小，默认为 128。\n* `tile_overlap_in_frames`: VAE 编解码阶段的帧分块重叠大小，默认为 24。\n* `use_two_stage_pipeline`: 是否使用两阶段管道，默认为 `False`。\n* `use_distilled_pipeline`: 是否使用蒸馏管道，默认为 `False`。\n* `progress_bar_cmd`: 进度条，默认为 `tqdm.tqdm`。可通过设置为 `lambda x:x` 来屏蔽进度条。\n\n如果显存不足，请开启[显存管理](../Pipeline_Usage/VRAM_management.md)，我们在示例代码中提供了每个模型推荐的低显存配置，详见前文\"模型总览\"中的表格。\n\n## 模型训练\n\nLTX-2 系列模型统一通过 [`examples/ltx2/model_training/train.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/train.py) 进行训练，脚本的参数包括：\n\n* 通用训练参数\n    * 数据集基础配置\n        * `--dataset_base_path`: 数据集的根目录。\n        * `--dataset_metadata_path`: 数据集的元数据文件路径。\n        * `--dataset_repeat`: 每个 epoch 中数据集重复的次数。\n        * `--dataset_num_workers`: 每个 DataLoader 的工作进程数量。\n        * `--data_file_keys`: 元数据中需要加载的字段名称，通常是图像或视频文件的路径，以 `,` 分隔。\n    * 模型加载配置\n        * `--model_paths`: 要加载的模型路径。JSON 格式。\n        * `--model_id_with_origin_paths`: 带原始路径的模型 ID，例如 `\"Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors\"`。用逗号分隔。\n        * `--extra_inputs`: 模型 Pipeline 所需的额外输入参数，例如训练图像编辑模型时需要额外参数，以 `,` 分隔。\n        * `--fp8_models`: 以 FP8 格式加载的模型，格式与 `--model_paths` 或 `--model_id_with_origin_paths` 一致，目前仅支持参数不被梯度更新的模型（不需要梯度回传，或梯度仅更新其 LoRA）。\n    * 训练基础配置\n        * `--learning_rate`: 学习率。\n        * `--num_epochs`: 轮数（Epoch）。\n        * `--trainable_models`: 可训练的模型，例如 `dit`、`vae`、`text_encoder`。\n        * `--find_unused_parameters`: DDP 训练中是否存在未使用的参数，少数模型包含不参与梯度计算的冗余参数，需开启这一设置避免在多 GPU 训练中报错。\n        * `--weight_decay`: 权重衰减大小，详见 [torch.optim.AdamW](https://docs.pytorch.org/docs/stable/generated/torch.optim.AdamW.html)。\n        * `--task`: 训练任务，默认为 
`sft`，部分模型支持更多训练模式，请参考每个特定模型的文档。\n    * 输出配置\n        * `--output_path`: 模型保存路径。\n        * `--remove_prefix_in_ckpt`: 在模型文件的 state dict 中移除前缀。\n        * `--save_steps`: 保存模型的训练步数间隔，若此参数留空，则每个 epoch 保存一次。\n    * LoRA 配置\n        * `--lora_base_model`: LoRA 添加到哪个模型上。\n        * `--lora_target_modules`: LoRA 添加到哪些层上。\n        * `--lora_rank`: LoRA 的秩（Rank）。\n        * `--lora_checkpoint`: LoRA 检查点的路径。如果提供此路径，LoRA 将从此检查点加载。\n        * `--preset_lora_path`: 预置 LoRA 检查点路径，如果提供此路径，这一 LoRA 将会以融入基础模型的形式加载。此参数用于 LoRA 差分训练。\n        * `--preset_lora_model`: 预置 LoRA 融入的模型，例如 `dit`。\n    * 梯度配置\n        * `--use_gradient_checkpointing`: 是否启用 gradient checkpointing。\n        * `--use_gradient_checkpointing_offload`: 是否将 gradient checkpointing 卸载到内存中。\n        * `--gradient_accumulation_steps`: 梯度累积步数。\n    * 视频宽高配置\n        * `--height`: 视频的高度。将 `height` 和 `width` 留空以启用动态分辨率。\n        * `--width`: 视频的宽度。将 `height` 和 `width` 留空以启用动态分辨率。\n        * `--max_pixels`: 视频帧的最大像素面积，当启用动态分辨率时，分辨率大于这个数值的视频帧都会被缩小，分辨率小于这个数值的视频帧保持不变。\n        * `--num_frames`: 视频的帧数。\n* LTX-2 系列特定参数\n    * `--tokenizer_path`: 分词器路径，适用于文生视频模型，留空则从远程自动下载。\n    * `--frame_rate`: 训练视频的帧率。\n\n我们构建了一个样例视频数据集，以方便您进行测试，通过以下命令可以下载这个数据集：\n\n```shell\nmodelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --local_dir ./data/diffsynth_example_dataset\n```\n\n我们为每个模型编写了推荐的训练脚本，请参考前文\"模型总览\"中的表格。关于如何编写模型训练脚本，请参考[模型训练](../Pipeline_Usage/Model_Training.md)；更多高阶训练算法，请参考[训练框架详解](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/zh/Training/)。\n"
  },
  {
    "path": "docs/zh/Model_Details/Overview.md",
    "content": "# 模型目录\n\n## Qwen-Image\n\n文档：[./Qwen-Image.md](../Model_Details/Qwen-Image.md)\n\n<details>\n\n<summary>效果一览</summary>\n\n![Image](https://github.com/user-attachments/assets/738078d8-8749-4a53-a046-571861541924)\n\n</details>\n\n<details>\n\n<summary>快速开始</summary>\n\n```python\nfrom diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nfrom PIL import Image\nimport torch\n\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n)\nprompt = \"精致肖像，水下少女，蓝裙飘逸，发丝轻扬，光影透澈，气泡环绕，面容恬静，细节精致，梦幻唯美。\"\nimage = pipe(\n    prompt, seed=0, num_inference_steps=40,\n    # edit_image=Image.open(\"xxx.jpg\").resize((1328, 1328)) # For Qwen-Image-Edit\n)\nimage.save(\"image.jpg\")\n```\n\n</details>\n\n<details>\n\n<summary>模型血缘</summary>\n\n```mermaid\ngraph LR;\n    Qwen/Qwen-Image-->Qwen/Qwen-Image-Edit;\n    Qwen/Qwen-Image-Edit-->Qwen/Qwen-Image-Edit-2509;\n    Qwen/Qwen-Image-->EliGen-Series;\n    EliGen-Series-->DiffSynth-Studio/Qwen-Image-EliGen;\n    DiffSynth-Studio/Qwen-Image-EliGen-->DiffSynth-Studio/Qwen-Image-EliGen-V2;\n    EliGen-Series-->DiffSynth-Studio/Qwen-Image-EliGen-Poster;\n    Qwen/Qwen-Image-->Distill-Series;\n    Distill-Series-->DiffSynth-Studio/Qwen-Image-Distill-Full;\n    Distill-Series-->DiffSynth-Studio/Qwen-Image-Distill-LoRA;\n    Qwen/Qwen-Image-->ControlNet-Series;\n    ControlNet-Series-->Blockwise-ControlNet-Series;\n    
Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny;\n    Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth;\n    Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint;\n    ControlNet-Series-->DiffSynth-Studio/Qwen-Image-In-Context-Control-Union;\n    Qwen/Qwen-Image-->DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix;\n```\n\n</details>\n\n|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|\n|-|-|-|-|-|-|-|\n|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image.py)|\n|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Q
wen-Image-Edit.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)|\n|[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)|\n|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|\n|[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inferen
ce_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|\n|[DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py)|-|-|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py)|\n|[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-Full.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py)|\n|[DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Q
wen-Image-Distill-LoRA)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Distill-LoRA.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-LoRA.py)|-|-|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-LoRA.py)|\n|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Canny.py)|\n|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vra
m/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Depth.py)|\n|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|\n|[DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py)|[code](https:
//github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py)|-|-|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py)|\n|[DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py)|-|-|-|-|\n\n## FLUX 系列\n\n文档：[./FLUX.md](../Model_Details/FLUX.md)\n\n<details>\n\n<summary>效果一览</summary>\n\n![Image](https://github.com/user-attachments/assets/c01258e2-f251-441a-aa1e-ebb22f02594d)\n\n</details>\n\n<details>\n\n<summary>快速开始</summary>\n\n```python\nimport torch\nfrom diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig\n\npipe = FluxImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"flux1-dev.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder/model.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder_2/*.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"ae.safetensors\"),\n    ],\n)\n\nimage = pipe(prompt=\"a cat\", seed=0)\nimage.save(\"image.jpg\")\n```\n\n</details>\n\n<details>\n\n<summary>模型血缘</summary>\n\n```mermaid\ngraph LR;\n    FLUX.1-Series-->black-forest-labs/FLUX.1-dev;\n    
FLUX.1-Series-->black-forest-labs/FLUX.1-Krea-dev;\n    FLUX.1-Series-->black-forest-labs/FLUX.1-Kontext-dev;\n    black-forest-labs/FLUX.1-dev-->FLUX.1-dev-ControlNet-Series;\n    FLUX.1-dev-ControlNet-Series-->alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta;\n    FLUX.1-dev-ControlNet-Series-->InstantX/FLUX.1-dev-Controlnet-Union-alpha;\n    FLUX.1-dev-ControlNet-Series-->jasperai/Flux.1-dev-Controlnet-Upscaler;\n    black-forest-labs/FLUX.1-dev-->InstantX/FLUX.1-dev-IP-Adapter;\n    black-forest-labs/FLUX.1-dev-->ByteDance/InfiniteYou;\n    black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/Eligen;\n    black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev;\n    black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev;\n    black-forest-labs/FLUX.1-dev-->ostris/Flex.2-preview;\n    black-forest-labs/FLUX.1-dev-->stepfun-ai/Step1X-Edit;\n    Qwen/Qwen2.5-VL-7B-Instruct-->stepfun-ai/Step1X-Edit;\n    black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/Nexus-GenV2;\n    Qwen/Qwen2.5-VL-7B-Instruct-->DiffSynth-Studio/Nexus-GenV2;\n```\n\n</details>\n\n|模型 ID|额外参数|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 
训练后验证|\n|-|-|-|-|-|-|-|-|\n|[black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev.py)|\n|[black-forest-labs/FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-Krea-dev.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-Krea-dev.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-Krea-dev.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-Krea-dev.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-Krea-dev.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-Krea-dev.py)|\n|[black-forest-labs/FLUX.1-Kontext-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev)|`kontext_images`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-Kontext-dev.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_
vram/FLUX.1-Kontext-dev.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-Kontext-dev.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-Kontext-dev.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-Kontext-dev.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-Kontext-dev.py)|\n|[alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta](https://www.modelscope.cn/models/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta)|`controlnet_inputs`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|\n|[InstantX/FLUX.1-dev-Controlnet-Union-alpha](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-Controlnet-Union-alpha)|`controlnet_inputs`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](https://github.co
m/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Union-alpha.py)|\n|[jasperai/Flux.1-dev-Controlnet-Upscaler](https://www.modelscope.cn/models/jasperai/Flux.1-dev-Controlnet-Upscaler)|`controlnet_inputs`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-Controlnet-Upscaler.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Upscaler.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Upscaler.py)|\n|[InstantX/FLUX.1-dev-IP-Adapter](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-IP-Adapter)|`ipadapter_images`, 
`ipadapter_scale`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-IP-Adapter.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-IP-Adapter.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-IP-Adapter.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-IP-Adapter.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-IP-Adapter.py)|\n|[ByteDance/InfiniteYou](https://www.modelscope.cn/models/ByteDance/InfiniteYou)|`infinityou_id_image`, `infinityou_guidance`, `controlnet_inputs`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-InfiniteYou.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-InfiniteYou.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-InfiniteYou.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-InfiniteYou.py)|\n|[DiffSynth-Studio/Eligen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)|`eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, 
`eligen_enable_inpaint`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-EliGen.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py)|-|-|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-EliGen.py)|\n|[DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev](https://www.modelscope.cn/models/DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev)|`lora_encoder_inputs`, `lora_encoder_scale`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-LoRA-Encoder.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py)|-|-|\n|[DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py)|-|-|-|-|-|\n|[stepfun-ai/Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit)|`step1x_reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/Step1X-Edit.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/Step1X-Edit.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/Step1X-Edit.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_t
raining/validate_full/Step1X-Edit.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/Step1X-Edit.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/Step1X-Edit.py)|\n|[ostris/Flex.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview)|`flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, `flex_control_stop`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLEX.2-preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLEX.2-preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLEX.2-preview.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLEX.2-preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLEX.2-preview.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLEX.2-preview.py)|\n|[DiffSynth-Studio/Nexus-GenV2](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2)|`nexus_gen_reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/Nexus-Gen-Editing.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/Nexus-Gen.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/Nexus-Gen.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/Nexus-Gen.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validat
e_lora/Nexus-Gen.py)|\n\n## Wan 系列\n\n文档：[./Wan.md](../Model_Details/Wan.md)\n\n<details>\n\n<summary>效果一览</summary>\n\nhttps://github.com/user-attachments/assets/1d66ae74-3b02-40a9-acc3-ea95fc039314\n\n</details>\n\n<details>\n\n<summary>快速开始</summary>\n\n```python\nimport torch\nfrom diffsynth.utils.data import save_video\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n    ],\n)\n\nvideo = pipe(\n    prompt=\"纪实摄影风格画面，一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄，两只耳朵立起，神情专注而欢快。阳光洒在它身上，使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地，偶尔点缀着几朵野花，远处隐约可见蓝天和几片白云。透视感鲜明，捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    seed=0, tiled=True,\n)\nsave_video(video, \"video1.mp4\", fps=15, quality=5)\n```\n\n</details>\n\n<details>\n\n<summary>模型血缘</summary>\n\n```mermaid\ngraph LR;\n    Wan-Series-->Wan2.1-Series;\n    Wan-Series-->Wan2.2-Series;\n    Wan2.1-Series-->Wan-AI/Wan2.1-T2V-1.3B;\n    Wan2.1-Series-->Wan-AI/Wan2.1-T2V-14B;\n    Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-I2V-14B-480P;\n    Wan-AI/Wan2.1-I2V-14B-480P-->Wan-AI/Wan2.1-I2V-14B-720P;\n    Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-FLF2V-14B-720P;\n    Wan-AI/Wan2.1-T2V-1.3B-->iic/VACE-Wan2.1-1.3B-Preview;\n    iic/VACE-Wan2.1-1.3B-Preview-->Wan-AI/Wan2.1-VACE-1.3B;\n    Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-VACE-14B;\n    Wan-AI/Wan2.1-T2V-1.3B-->Wan2.1-Fun-1.3B-Series;\n    Wan2.1-Fun-1.3B-Series-->PAI/Wan2.1-Fun-1.3B-InP;\n    
Wan2.1-Fun-1.3B-Series-->PAI/Wan2.1-Fun-1.3B-Control;\n    Wan-AI/Wan2.1-T2V-14B-->Wan2.1-Fun-14B-Series;\n    Wan2.1-Fun-14B-Series-->PAI/Wan2.1-Fun-14B-InP;\n    Wan2.1-Fun-14B-Series-->PAI/Wan2.1-Fun-14B-Control;\n    Wan-AI/Wan2.1-T2V-1.3B-->Wan2.1-Fun-V1.1-1.3B-Series;\n    Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-Control;\n    Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-InP;\n    Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera;\n    Wan-AI/Wan2.1-T2V-14B-->Wan2.1-Fun-V1.1-14B-Series;\n    Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-Control;\n    Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-InP;\n    Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-Control-Camera;\n    Wan-AI/Wan2.1-T2V-1.3B-->DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1;\n    Wan-AI/Wan2.1-T2V-14B-->krea/krea-realtime-video;\n    Wan-AI/Wan2.1-T2V-14B-->meituan-longcat/LongCat-Video;\n    Wan-AI/Wan2.1-I2V-14B-720P-->ByteDance/Video-As-Prompt-Wan2.1-14B;\n    Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.2-Animate-14B;\n    Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.2-S2V-14B;\n    Wan2.2-Series-->Wan-AI/Wan2.2-T2V-A14B;\n    Wan2.2-Series-->Wan-AI/Wan2.2-I2V-A14B;\n    Wan2.2-Series-->Wan-AI/Wan2.2-TI2V-5B;\n    Wan-AI/Wan2.2-T2V-A14B-->Wan2.2-Fun-Series;\n    Wan2.2-Fun-Series-->PAI/Wan2.2-VACE-Fun-A14B;\n    Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-InP;\n    Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-Control;\n    Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-Control-Camera;\n```\n\n</details>\n\n|模型 ID|额外参数|推理|全量训练|全量训练后验证|LoRA 训练|LoRA 
训练后验证|\n|-|-|-|-|-|-|-|\n|[Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-1.3B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-1.3B.py)|\n|[Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-T2V-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py)|\n|[Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/mode
l_training/lora/Wan2.1-I2V-14B-480P.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py)|\n|[Wan-AI/Wan2.1-I2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P)|`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-I2V-14B-720P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-720P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-720P.py)|\n|[Wan-AI/Wan2.1-FLF2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P)|`input_image`, `end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-FLF2V-14B-720P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-FLF2V-14B-720P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-FLF2V-14B-720P.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py)|\n|[iic/VACE-Wan2.1-1.3B-Preview](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview)|`vace_control_video`, 
`vace_reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B-Preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B-Preview.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B-Preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B-Preview.py)|\n|[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|`vace_control_video`, `vace_reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B.py)|\n|[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, 
`vace_reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-VACE-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py)|\n|[PAI/Wan2.1-Fun-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP)|`input_image`, `end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py)|\n|[PAI/Wan2.1-Fun-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control)|`control_video`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/
lora/Wan2.1-Fun-1.3B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py)|\n|[PAI/Wan2.1-Fun-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP)|`input_image`, `end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-InP.py)|\n|[PAI/Wan2.1-Fun-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control)|`control_video`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-Control.py)|\n|[PAI/Wan2.1-Fun-V1.1-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control)|`control_video`, 
`reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py)|\n|[PAI/Wan2.1-Fun-V1.1-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control)|`control_video`, `reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py)|\n|[PAI/Wan2.1-Fun-V1.1-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-InP)|`input_image`, 
`end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-InP.py)|\n|[PAI/Wan2.1-Fun-V1.1-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-InP)|`input_image`, `end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-InP.py)|\n|[PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera)|`control_camera_video`, 
`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|\n|[PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera)|`control_camera_video`, `input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|\n|[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-1.3b-spee
dcontrol-v1.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)|\n|[krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/krea-realtime-video.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/krea-realtime-video.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/krea-realtime-video.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/krea-realtime-video.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/krea-realtime-video.py)|\n|[meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video)|`longcat_video`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/LongCat-Video.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/LongCat-Video.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/LongCat-Video.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/LongCat-Video.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/LongCat-Video.py)|\n|[ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2
.1-14B)|`vap_video`, `vap_prompt`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Video-As-Prompt-Wan2.1-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Video-As-Prompt-Wan2.1-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Video-As-Prompt-Wan2.1-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Video-As-Prompt-Wan2.1-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Video-As-Prompt-Wan2.1-14B.py)|\n|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py)|\n|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-I2V-A1
4B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py)|\n|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py)|\n|[Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B)|`input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Animate-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py)|\n|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, 
`s2v_pose_video`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-S2V-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-S2V-14B.py)|\n|[PAI/Wan2.2-VACE-Fun-A14B](https://www.modelscope.cn/models/PAI/Wan2.2-VACE-Fun-A14B)|`vace_control_video`, `vace_reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-VACE-Fun-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-VACE-Fun-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-VACE-Fun-A14B.py)|\n|[PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP)|`input_image`, 
`end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py)|\n|[PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control)|`control_video`, `reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py)|\n|[PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera)|`control_camera_video`, 
`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py)|\n"
  },
  {
    "path": "docs/zh/Model_Details/Qwen-Image.md",
    "content": "# Qwen-Image\n\n![Image](https://github.com/user-attachments/assets/738078d8-8749-4a53-a046-571861541924)\n\nQwen-Image 是由阿里巴巴通义实验室通义千问团队训练并开源的图像生成模型。\n\n## 安装\n\n在使用本项目进行模型推理和训练前，请先安装 DiffSynth-Studio。\n\n```shell\ngit clone https://github.com/modelscope/DiffSynth-Studio.git\ncd DiffSynth-Studio\npip install -e .\n```\n\n更多关于安装的信息，请参考[安装依赖](../Pipeline_Usage/Setup.md)。\n\n## 快速开始\n\n运行以下代码可以快速加载 [Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image) 模型并进行推理。显存管理已启动，框架会自动根据剩余显存控制模型参数的加载，最低 8G 显存即可运行。\n\n```python\nfrom diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nimport torch\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.float8_e4m3fn,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e4m3fn,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\nprompt = \"精致肖像，水下少女，蓝裙飘逸，发丝轻扬，光影透澈，气泡环绕，面容恬静，细节精致，梦幻唯美。\"\nimage = pipe(prompt, seed=0, num_inference_steps=40)\nimage.save(\"image.jpg\")\n```\n\n## 模型总览\n\n<details>\n\n<summary>模型血缘</summary>\n\n```mermaid\ngraph LR;\n    Qwen/Qwen-Image-->Qwen/Qwen-Image-Edit;\n    Qwen/Qwen-Image-Edit-->Qwen/Qwen-Image-Edit-2509;\n   
 Qwen/Qwen-Image-->EliGen-Series;\n    EliGen-Series-->DiffSynth-Studio/Qwen-Image-EliGen;\n    DiffSynth-Studio/Qwen-Image-EliGen-->DiffSynth-Studio/Qwen-Image-EliGen-V2;\n    EliGen-Series-->DiffSynth-Studio/Qwen-Image-EliGen-Poster;\n    Qwen/Qwen-Image-->Distill-Series;\n    Distill-Series-->DiffSynth-Studio/Qwen-Image-Distill-Full;\n    Distill-Series-->DiffSynth-Studio/Qwen-Image-Distill-LoRA;\n    Qwen/Qwen-Image-->ControlNet-Series;\n    ControlNet-Series-->Blockwise-ControlNet-Series;\n    Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny;\n    Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth;\n    Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint;\n    ControlNet-Series-->DiffSynth-Studio/Qwen-Image-In-Context-Control-Union;\n    Qwen/Qwen-Image-->DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix;\n```\n\n</details>\n\n|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|\n|-|-|-|-|-|-|-|\n|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image.py)|\n|[Qwen/Qwen-Image-2512](https://www.modelscope.cn/models/Qwen/Qwen-Image-2512)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Im
age-2512.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-2512.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-2512.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-2512.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-2512.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-2512.py)|\n|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)|\n|[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](https://github
.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)|\n|[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)|\n|[FireRedTeam/FireRed-Image-Edit-1.0](https://www.modelscope.cn/models/FireRedTeam/FireRed-Image-Edit-1.0)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/FireRed-Image-Edit-1.0.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/FireRed-Image-Edit-1.0.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/FireRed-Image-Edit-1.0.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/FireRed-Image-Edit-1.0.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model
_training/lora/FireRed-Image-Edit-1.0.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/FireRed-Image-Edit-1.0.py)|\n|[FireRedTeam/FireRed-Image-Edit-1.1](https://www.modelscope.cn/models/FireRedTeam/FireRed-Image-Edit-1.1)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/FireRed-Image-Edit-1.1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/FireRed-Image-Edit-1.1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/FireRed-Image-Edit-1.1.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/FireRed-Image-Edit-1.1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/FireRed-Image-Edit-1.1.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/FireRed-Image-Edit-1.1.py)|\n|[lightx2v/Qwen-Image-Edit-2511-Lightning](https://modelscope.cn/models/lightx2v/Qwen-Image-Edit-2511-Lightning)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-Lightning.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511-Lightning.py)|-|-|-|-|\n|[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](https://github.com/modelsc
ope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)|\n|[DiffSynth-Studio/Qwen-Image-Layered-Control](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Layered-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Layered-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control.py)|\n|[DiffSynth-Studio/Qwen-Image-Layered-Control-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control-V2)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Layered-Control-V2.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control-V2.py)|-|-|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control-V2.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-La
yered-Control-V2.py)|\n|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|\n|[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|\n|[DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py)|-|-|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py)|\n|[DiffSynth-Studio/Qwen-Image-Distill-Full
](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-Full.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py)|\n|[DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Distill-LoRA.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-LoRA.py)|-|-|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-LoRA.py)|\n|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](http
s://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Canny.py)|\n|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Depth.py)|\n|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwe
n_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|\n|[DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py)|-|-|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py)|\n|[DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py)|-|-|-|-|\n|[DiffSynth-Studio/Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examp
les/qwen_image/model_inference/Qwen-Image-i2L.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-i2L.py)|-|-|-|-|\n\n特殊训练脚本：\n\n* 差分 LoRA 训练：[doc](../Training/Differential_LoRA.md)、[code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/qwen_image/model_training/special/differential_training/)\n* FP8 精度训练：[doc](../Training/FP8_Precision.md)、[code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/qwen_image/model_training/special/fp8_training/)\n* 两阶段拆分训练：[doc](../Training/Split_Training.md)、[code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/qwen_image/model_training/special/split_training/)\n* 端到端直接蒸馏：[doc](../Training/Direct_Distill.md)、[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh)\n\nDeepSpeed ZeRO 3 训练：Qwen-Image 系列模型支持 DeepSpeed ZeRO 3 训练，将模型拆分到多个 GPU 上，以 Qwen-Image 模型的全量训练为例，需修改：\n\n* `--config_file examples/qwen_image/model_training/full/accelerate_config_zero3.yaml`\n* `--initialize_model_on_cpu`\n\n## 模型推理\n\n模型通过 `QwenImagePipeline.from_pretrained` 加载，详见[加载模型](../Pipeline_Usage/Model_Inference.md#加载模型)。\n\n`QwenImagePipeline` 推理的输入参数包括：\n\n* `prompt`: 提示词，描述画面中出现的内容。\n* `negative_prompt`: 负向提示词，描述画面中不应该出现的内容，默认值为 `\"\"`。\n* `cfg_scale`: Classifier-free guidance 的参数，默认值为 4，当设置为 1 时不再生效。\n* `input_image`: 输入图像，用于图生图，该参数与 `denoising_strength` 配合使用。\n* `denoising_strength`: 去噪强度，范围是 0～1，默认值为 1，当数值接近 0 时，生成图像与输入图像相似；当数值接近 1 时，生成图像与输入图像相差更大。在不输入 `input_image` 参数时，请不要将其设置为非 1 的数值。\n* `inpaint_mask`: 图像局部重绘的遮罩图像。\n* `inpaint_blur_size`: 图像局部重绘的边缘柔化宽度。\n* `inpaint_blur_sigma`: 图像局部重绘的边缘柔化强度。\n* `height`: 图像高度，需保证高度为 16 的倍数。\n* `width`: 图像宽度，需保证宽度为 16 的倍数。\n* `seed`: 随机种子。默认为 `None`，即完全随机。\n* `rand_device`: 生成随机高斯噪声矩阵的计算设备，默认为 `\"cpu\"`。当设置为 `cuda` 时，在不同 GPU 上会导致不同的生成结果。\n* `num_inference_steps`: 推理次数，默认值为 30。\n* `exponential_shift_mu`: 
在采样时间步时采用的固定参数，留空则根据图像宽高进行采样。\n* `blockwise_controlnet_inputs`: Blockwise ControlNet 模型的输入。\n* `eligen_entity_prompts`: EliGen 分区控制的提示词。\n* `eligen_entity_masks`: EliGen 分区控制的区域遮罩图像。\n* `eligen_enable_on_negative`: 是否在 CFG 的负向一侧启用 EliGen 分区控制。\n* `edit_image`: 编辑模型的待编辑图像，支持多张图像。\n* `edit_image_auto_resize`: 是否自动缩放待编辑图像。\n* `edit_rope_interpolation`: 是否在低分辨率编辑图像上启用 RoPE 插值。\n* `context_image`: In-Context Control 的输入图像。\n* `tiled`: 是否启用 VAE 分块推理，默认为 `False`。设置为 `True` 时可显著减少 VAE 编解码阶段的显存占用，会产生少许误差，以及少量推理时间延长。\n* `tile_size`: VAE 编解码阶段的分块大小，默认为 128，仅在 `tiled=True` 时生效。\n* `tile_stride`: VAE 编解码阶段的分块步长，默认为 64，仅在 `tiled=True` 时生效，需保证其数值小于或等于 `tile_size`。\n* `progress_bar_cmd`: 进度条，默认为 `tqdm.tqdm`。可通过设置为 `lambda x:x` 来屏蔽进度条。\n\n如果显存不足，请开启[显存管理](../Pipeline_Usage/VRAM_management.md)，我们在示例代码中提供了每个模型推荐的低显存配置，详见前文“模型总览”中的表格。\n\n## 模型训练\n\nQwen-Image 系列模型统一通过 [`examples/qwen_image/model_training/train.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/train.py) 进行训练，脚本的参数包括：\n\n* 通用训练参数\n    * 数据集基础配置\n        * `--dataset_base_path`: 数据集的根目录。\n        * `--dataset_metadata_path`: 数据集的元数据文件路径。\n        * `--dataset_repeat`: 每个 epoch 中数据集重复的次数。\n        * `--dataset_num_workers`: 每个 DataLoader 的进程数量。\n        * `--data_file_keys`: 元数据中需要加载的字段名称，通常是图像或视频文件的路径，以 `,` 分隔。\n    * 模型加载配置\n        * `--model_paths`: 要加载的模型路径。JSON 格式。\n        * `--model_id_with_origin_paths`: 带原始路径的模型 ID，例如 `\"Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors\"`。用逗号分隔。\n        * `--extra_inputs`: 模型 Pipeline 所需的额外输入参数，例如训练图像编辑模型 Qwen-Image-Edit 时需要额外参数 `edit_image`，以 `,` 分隔。\n        * `--fp8_models`：以 FP8 格式加载的模型，格式与 `--model_paths` 或 `--model_id_with_origin_paths` 一致，目前仅支持参数不被梯度更新的模型（不需要梯度回传，或梯度仅更新其 LoRA）。\n    * 训练基础配置\n        * `--learning_rate`: 学习率。\n        * `--num_epochs`: 轮数（Epoch）。\n        * `--trainable_models`: 可训练的模型，例如 `dit`、`vae`、`text_encoder`。\n        * `--find_unused_parameters`: DDP 
训练中是否存在未使用的参数，少数模型包含不参与梯度计算的冗余参数，需开启这一设置避免在多 GPU 训练中报错。\n        * `--weight_decay`：权重衰减大小，详见 [torch.optim.AdamW](https://docs.pytorch.org/docs/stable/generated/torch.optim.AdamW.html)。\n        * `--task`: 训练任务，默认为 `sft`，部分模型支持更多训练模式，请参考每个特定模型的文档。\n    * 输出配置\n        * `--output_path`: 模型保存路径。\n        * `--remove_prefix_in_ckpt`: 在模型文件的 state dict 中移除前缀。\n        * `--save_steps`: 保存模型的训练步数间隔，若此参数留空，则每个 epoch 保存一次。\n    * LoRA 配置\n        * `--lora_base_model`: LoRA 添加到哪个模型上。\n        * `--lora_target_modules`: LoRA 添加到哪些层上。\n        * `--lora_rank`: LoRA 的秩（Rank）。\n        * `--lora_checkpoint`: LoRA 检查点的路径。如果提供此路径，LoRA 将从此检查点加载。\n        * `--preset_lora_path`: 预置 LoRA 检查点路径，如果提供此路径，这一 LoRA 将会以融入基础模型的形式加载。此参数用于 LoRA 差分训练。\n        * `--preset_lora_model`: 预置 LoRA 融入的模型，例如 `dit`。\n    * 梯度配置\n        * `--use_gradient_checkpointing`: 是否启用 gradient checkpointing。\n        * `--use_gradient_checkpointing_offload`: 是否将 gradient checkpointing 卸载到内存中。\n        * `--gradient_accumulation_steps`: 梯度累积步数。\n    * 图像宽高配置（适用于图像生成模型和视频生成模型）\n        * `--height`: 图像或视频的高度。将 `height` 和 `width` 留空以启用动态分辨率。\n        * `--width`: 图像或视频的宽度。将 `height` 和 `width` 留空以启用动态分辨率。\n        * `--max_pixels`: 图像或视频帧的最大像素面积，当启用动态分辨率时，分辨率大于这个数值的图片都会被缩小，分辨率小于这个数值的图片保持不变。\n* Qwen-Image 专有参数\n    * `--tokenizer_path`: tokenizer 的路径，适用于文生图模型，留空则自动从远程下载。\n    * `--processor_path`: processor 的路径，适用于图像编辑模型，留空则自动从远程下载。\n\n我们构建了一个样例图像数据集，以方便您进行测试，通过以下命令可以下载这个数据集：\n\n```shell\nmodelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --local_dir ./data/diffsynth_example_dataset\n```\n\n我们为每个模型编写了推荐的训练脚本，请参考前文“模型总览”中的表格。关于如何编写模型训练脚本，请参考[模型训练](../Pipeline_Usage/Model_Training.md)；更多高阶训练算法，请参考[训练框架详解](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/zh/Training/)。\n"
  },
  {
    "path": "docs/zh/Model_Details/Wan.md",
    "content": "# Wan\n\nhttps://github.com/user-attachments/assets/1d66ae74-3b02-40a9-acc3-ea95fc039314\n\nWan 是由阿里巴巴通义实验室通义万相团队开发的视频生成模型系列。\n\n## 安装\n\n在使用本项目进行模型推理和训练前，请先安装 DiffSynth-Studio。\n\n```shell\ngit clone https://github.com/modelscope/DiffSynth-Studio.git\ncd DiffSynth-Studio\npip install -e .\n```\n\n更多关于安装的信息，请参考[安装依赖](../Pipeline_Usage/Setup.md)。\n\n## 快速开始\n\n运行以下代码可以快速加载 [Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B) 模型并进行推理。显存管理已启动，框架会自动根据剩余显存控制模型参数的加载，最低 8G 显存即可运行。\n\n```python\nimport torch\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\", **vram_config),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"Wan2.1_VAE.pth\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 2,\n)\n\nvideo = pipe(\n    prompt=\"纪实摄影风格画面，一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄，两只耳朵立起，神情专注而欢快。阳光洒在它身上，使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地，偶尔点缀着几朵野花，远处隐约可见蓝天和几片白云。透视感鲜明，捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。\",\n    
negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    seed=0, tiled=True,\n)\nsave_video(video, \"video.mp4\", fps=15, quality=5)\n```\n\n## 模型总览\n\n<details>\n\n<summary>模型血缘</summary>\n\n```mermaid\ngraph LR;\n    Wan-Series-->Wan2.1-Series;\n    Wan-Series-->Wan2.2-Series;\n    Wan2.1-Series-->Wan-AI/Wan2.1-T2V-1.3B;\n    Wan2.1-Series-->Wan-AI/Wan2.1-T2V-14B;\n    Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-I2V-14B-480P;\n    Wan-AI/Wan2.1-I2V-14B-480P-->Wan-AI/Wan2.1-I2V-14B-720P;\n    Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-FLF2V-14B-720P;\n    Wan-AI/Wan2.1-T2V-1.3B-->iic/VACE-Wan2.1-1.3B-Preview;\n    iic/VACE-Wan2.1-1.3B-Preview-->Wan-AI/Wan2.1-VACE-1.3B;\n    Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-VACE-14B;\n    Wan-AI/Wan2.1-T2V-1.3B-->Wan2.1-Fun-1.3B-Series;\n    Wan2.1-Fun-1.3B-Series-->PAI/Wan2.1-Fun-1.3B-InP;\n    Wan2.1-Fun-1.3B-Series-->PAI/Wan2.1-Fun-1.3B-Control;\n    Wan-AI/Wan2.1-T2V-14B-->Wan2.1-Fun-14B-Series;\n    Wan2.1-Fun-14B-Series-->PAI/Wan2.1-Fun-14B-InP;\n    Wan2.1-Fun-14B-Series-->PAI/Wan2.1-Fun-14B-Control;\n    Wan-AI/Wan2.1-T2V-1.3B-->Wan2.1-Fun-V1.1-1.3B-Series;\n    Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-Control;\n    Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-InP;\n    Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera;\n    Wan-AI/Wan2.1-T2V-14B-->Wan2.1-Fun-V1.1-14B-Series;\n    Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-Control;\n    Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-InP;\n    Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-Control-Camera;\n    Wan-AI/Wan2.1-T2V-1.3B-->DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1;\n    Wan-AI/Wan2.1-T2V-14B-->krea/krea-realtime-video;\n    Wan-AI/Wan2.1-T2V-14B-->meituan-longcat/LongCat-Video;\n    Wan-AI/Wan2.1-I2V-14B-720P-->ByteDance/Video-As-Prompt-Wan2.1-14B;\n    
Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.2-Animate-14B;\n    Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.2-S2V-14B;\n    Wan2.2-Series-->Wan-AI/Wan2.2-T2V-A14B;\n    Wan2.2-Series-->Wan-AI/Wan2.2-I2V-A14B;\n    Wan2.2-Series-->Wan-AI/Wan2.2-TI2V-5B;\n    Wan-AI/Wan2.2-T2V-A14B-->Wan2.2-Fun-Series;\n    Wan2.2-Fun-Series-->PAI/Wan2.2-VACE-Fun-A14B;\n    Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-InP;\n    Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-Control;\n    Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-Control-Camera;\n```\n\n</details>\n\n|模型 ID|额外参数|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|\n|-|-|-|-|-|-|-|-|\n|[Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-T2V-1.3B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-1.3B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-1.3B.py)|\n|[Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-T2V-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-T2V-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.p
y)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py)|\n|[Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-I2V-14B-480P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py)|\n|[Wan-AI/Wan2.1-I2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P)|`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-I2V-14B-720P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-I2V-14B-720P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-720P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B
-720P.py)|\n|[Wan-AI/Wan2.1-FLF2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P)|`input_image`, `end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-FLF2V-14B-720P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-FLF2V-14B-720P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-FLF2V-14B-720P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-FLF2V-14B-720P.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py)|\n|[iic/VACE-Wan2.1-1.3B-Preview](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview)|`vace_control_video`, `vace_reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B-Preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-VACE-1.3B-Preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B-Preview.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B-Preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B-Preview.py)|\n|[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|`vace_control_video`, 
`vace_reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-VACE-1.3B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B.py)|\n|[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, `vace_reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-VACE-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-VACE-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py)|\n|[PAI/Wan2.1-Fun-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP)|`input_image`, 
`end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-1.3B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py)|\n|[PAI/Wan2.1-Fun-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control)|`control_video`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-1.3B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py)|\n|[PAI/Wan2.1-Fun-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP)|`input_image`, 
`end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-InP.py)|\n|[PAI/Wan2.1-Fun-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control)|`control_video`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-Control.py)|\n|[PAI/Wan2.1-Fun-V1.1-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control)|`control_video`, 
`reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py)|\n|[PAI/Wan2.1-Fun-V1.1-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control)|`control_video`, `reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py)|\n|[PAI/Wan2.1-Fun-V1.1-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-InP)|`input_image`, 
`end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-InP.py)|\n|[PAI/Wan2.1-Fun-V1.1-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-InP)|`input_image`, `end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-InP.py)|\n|[PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera)|`control_camera_video`, 
`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|\n|[PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera)|`control_camera_video`, 
`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|\n|[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-1.3b-speedcontrol-v1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)|\n|[krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/
examples/wanvideo/model_inference/krea-realtime-video.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/krea-realtime-video.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/krea-realtime-video.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/krea-realtime-video.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/krea-realtime-video.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/krea-realtime-video.py)|\n|[meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video)|`longcat_video`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/LongCat-Video.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/LongCat-Video.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/LongCat-Video.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/LongCat-Video.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/LongCat-Video.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/LongCat-Video.py)|\n|[ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B)|`vap_video`, 
`vap_prompt`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Video-As-Prompt-Wan2.1-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Video-As-Prompt-Wan2.1-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Video-As-Prompt-Wan2.1-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Video-As-Prompt-Wan2.1-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Video-As-Prompt-Wan2.1-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Video-As-Prompt-Wan2.1-14B.py)|\n|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-T2V-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py)|\n|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-I2V-A14B.p
y)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py)|\n|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-TI2V-5B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py)|\n|[Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B)|`input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, 
`animate_mask_video`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Animate-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-Animate-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py)|\n|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-S2V-14B_multi_clips.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-S2V-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-S2V-14B.py)|\n|[PAI/Wan2.2-VACE-Fun-A14B](https://www.modelscope.cn/models/PAI/Wan2.2-VACE-Fun-A14B)|`vace_control_video`, 
`vace_reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-VACE-Fun-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-VACE-Fun-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-VACE-Fun-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-VACE-Fun-A14B.py)|\n|[PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP)|`input_image`, `end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-Fun-A14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py)|\n|[PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control)|`control_video`, 
`reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-Fun-A14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py)|\n|[PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera)|`control_camera_video`, `input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-Fun-A14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py)|\n|[openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p)|`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_inference/MOVA-360p-I2AV.py)
|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_inference_low_vram/MOVA-360p-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/full/MOVA-360P-I2AV.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_full/MOVA-360p-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/lora/MOVA-360P-I2AV.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_lora/MOVA-360p-I2AV.py)|\n|[openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p)|`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_inference/MOVA-720p-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_inference_low_vram/MOVA-720p-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/full/MOVA-720P-I2AV.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_full/MOVA-720p-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/lora/MOVA-720P-I2AV.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_lora/MOVA-720p-I2AV.py)|\n|[Wan-AI/WanToDance-14B (global model)](https://modelscope.cn/models/Wan-AI/WanToDance-14B)|`wantodance_music_path`, `wantodance_reference_image`, `wantodance_fps`, `wantodance_keyframes`, 
`wantodance_keyframes_mask`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/WanToDance-14B-global.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/WanToDance-14B-global.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/WanToDance-14B-global.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/WanToDance-14B-global.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/WanToDance-14B-global.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/WanToDance-14B-global.py)|\n|[Wan-AI/WanToDance-14B (local model)](https://modelscope.cn/models/Wan-AI/WanToDance-14B)|`wantodance_music_path`, `wantodance_reference_image`, `wantodance_fps`, `wantodance_keyframes`, `wantodance_keyframes_mask`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/WanToDance-14B-local.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/WanToDance-14B-local.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/WanToDance-14B-local.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/WanToDance-14B-local.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/WanToDance-14B-local.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/WanToDance-14B-local.py)|\n\n* FP8 精度训练：[doc](../Training/FP8_Precision.md)、[code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo/model_training/special/fp8_training/)\n* 
两阶段拆分训练：[doc](../Training/Split_Training.md)、[code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo/model_training/special/split_training/)\n* 端到端直接蒸馏：[doc](../Training/Direct_Distill.md)、[code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo/model_training/special/direct_distill/)\n\nDeepSpeed ZeRO 3 训练：Wan 系列模型支持 DeepSpeed ZeRO 3 训练，将模型拆分到多个 GPU 上，以 Wan2.1-T2V-14B 模型的全量训练为例，需修改：\n\n* `--config_file examples/wanvideo/model_training/full/accelerate_config_zero3.yaml`\n* `--initialize_model_on_cpu`\n\n## 模型推理\n\n模型通过 `WanVideoPipeline.from_pretrained` 加载，详见[加载模型](../Pipeline_Usage/Model_Inference.md#加载模型)。\n\n`WanVideoPipeline` 推理的输入参数包括：\n\n* `prompt`: 提示词，描述视频中出现的内容。\n* `negative_prompt`: 负向提示词，描述视频中不应该出现的内容，默认值为 `\"\"`。\n* `cfg_scale`: Classifier-free guidance 的参数，默认值为 5，当设置为 1 时不再生效。\n* `input_image`: 输入图像，用于图生视频，该参数与 `denoising_strength` 配合使用。\n* `end_image`: 结束图像，用于首尾帧生成视频。\n* `input_video`: 输入视频，用于视频到视频生成，该参数与 `denoising_strength` 配合使用。\n* `denoising_strength`: 去噪强度，范围是 0～1，默认值为 1，当数值接近 0 时，生成视频与输入视频相似；当数值接近 1 时，生成视频与输入视频相差更大。\n* `control_video`: 控制视频，用于控制视频生成过程。\n* `reference_image`: 参考图像，用于保持生成视频中某些特征的一致性。\n* `camera_control_direction`: 相机控制方向，可选值为 `\"Left\"`, `\"Right\"`, `\"Up\"`, `\"Down\"`, `\"LeftUp\"`, `\"LeftDown\"`, `\"RightUp\"`, `\"RightDown\"`。\n* `camera_control_speed`: 相机控制速度，默认值为 1/54。\n* `vace_video`: VACE 控制视频。\n* `vace_video_mask`: VACE 控制视频遮罩。\n* `vace_reference_image`: VACE 参考图像。\n* `vace_scale`: VACE 控制强度，默认值为 1.0。\n* `animate_pose_video`: `animate` 模型姿态视频。\n* `animate_face_video`: `animate` 模型面部视频。\n* `animate_inpaint_video`: `animate` 模型局部编辑视频。\n* `animate_mask_video`: `animate` 模型遮罩视频。\n* `vap_video`: `video-as-prompt` 的输入视频。\n* `vap_prompt`: `video-as-prompt` 的文本描述。\n* `negative_vap_prompt`: `video-as-prompt` 的负向文本描述。\n* `input_audio`: 输入音频，用于语音到视频生成。\n* `audio_embeds`: 音频嵌入向量。\n* `audio_sample_rate`: 音频采样率，默认值为 16000。\n* `s2v_pose_video`: S2V 模型的姿态视频。\n* `motion_video`: S2V 
模型的运动视频。\n* `height`: 视频高度，需保证高度为 16 的倍数。\n* `width`: 视频宽度，需保证宽度为 16 的倍数。\n* `num_frames`: 视频帧数，默认值为 81，需保证为 4 的倍数 + 1。\n* `seed`: 随机种子。默认为 `None`，即完全随机。\n* `rand_device`: 生成随机高斯噪声矩阵的计算设备，默认为 `\"cpu\"`。当设置为 `cuda` 时，在不同 GPU 上会导致不同的生成结果。\n* `num_inference_steps`: 推理步数，默认值为 50。\n* `motion_bucket_id`: 运动控制参数，数值越大，运动幅度越大。\n* `longcat_video`: LongCat 输入视频。\n* `tiled`: 是否启用 VAE 分块推理，默认为 `True`。设置为 `True` 时可显著减少 VAE 编解码阶段的显存占用，会产生少许误差，以及少量推理时间延长。\n* `tile_size`: VAE 编解码阶段的分块大小，默认为 `(30, 52)`，仅在 `tiled=True` 时生效。\n* `tile_stride`: VAE 编解码阶段的分块步长，默认为 `(15, 26)`，仅在 `tiled=True` 时生效，需保证其数值小于或等于 `tile_size`。\n* `switch_DiT_boundary`: 切换 DiT 模型的时间边界，默认值为 0.875。\n* `sigma_shift`: 时间步偏移参数，默认值为 5.0。\n* `sliding_window_size`: 滑动窗口大小。\n* `sliding_window_stride`: 滑动窗口步长。\n* `tea_cache_l1_thresh`: TeaCache 的 L1 阈值。\n* `tea_cache_model_id`: TeaCache 使用的模型 ID。\n* `progress_bar_cmd`: 进度条，默认为 `tqdm.tqdm`。可通过设置为 `lambda x:x` 来屏蔽进度条。\n\n如果显存不足，请开启[显存管理](../Pipeline_Usage/VRAM_management.md)，我们在示例代码中提供了每个模型推荐的低显存配置，详见前文\"模型总览\"中的表格。\n\n### 多卡并行加速\n\n如需开启多卡并行加速，请先安装 `flash_attn` 与 `xfuser`：\n\n```shell\npip install flash-attn --no-build-isolation\npip install xfuser\n```\n\n对代码进行如下修改（[样例代码](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo/acceleration/unified_sequence_parallel.py)）：\n\n```diff\nimport torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\n+ import torch.distributed as dist\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n+   use_usp=True,\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-14B\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-14B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-14B\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n    ],\n    
tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n)\nvideo = pipe(\n    prompt=\"一名宇航员身穿太空服，面朝镜头骑着一匹机械马在火星表面驰骋。红色的荒凉地表延伸至远方，点缀着巨大的陨石坑和奇特的岩石结构。机械马的步伐稳健，扬起微弱的尘埃，展现出未来科技与原始探索的完美结合。宇航员手持操控装置，目光坚定，仿佛正在开辟人类的新疆域。背景是深邃的宇宙和蔚蓝的地球，画面既科幻又充满希望，让人不禁畅想未来的星际生活。\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    seed=0, tiled=True,\n)\n+ if dist.get_rank() == 0:\n+   save_video(video, \"video1.mp4\", fps=15, quality=5)\n```\n\n运行多卡并行推理时，请使用 `torchrun` 运行，其中 `--nproc_per_node` 为 GPU 数量：\n\n```shell\ntorchrun --nproc_per_node=8 examples/wanvideo/acceleration/unified_sequence_parallel.py\n```\n\n## 模型训练\n\nWan 系列模型统一通过 [`examples/wanvideo/model_training/train.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/train.py) 进行训练，脚本的参数包括：\n\n* 通用训练参数\n    * 数据集基础配置\n        * `--dataset_base_path`: 数据集的根目录。\n        * `--dataset_metadata_path`: 数据集的元数据文件路径。\n        * `--dataset_repeat`: 每个 epoch 中数据集重复的次数。\n        * `--dataset_num_workers`: 每个 DataLoader 的进程数量。\n        * `--data_file_keys`: 元数据中需要加载的字段名称，通常是图像或视频文件的路径，以 `,` 分隔。\n    * 模型加载配置\n        * `--model_paths`: 要加载的模型路径。JSON 格式。\n        * `--model_id_with_origin_paths`: 带原始路径的模型 ID，例如 `\"Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors\"`。用逗号分隔。\n        * `--extra_inputs`: 模型 Pipeline 所需的额外输入参数，例如训练图像编辑模型时需要额外参数，以 `,` 分隔。\n        * `--fp8_models`：以 FP8 格式加载的模型，格式与 `--model_paths` 或 `--model_id_with_origin_paths` 一致，目前仅支持参数不被梯度更新的模型（不需要梯度回传，或梯度仅更新其 LoRA）。\n    * 训练基础配置\n        * `--learning_rate`: 学习率。\n        * `--num_epochs`: 轮数（Epoch）。\n        * `--trainable_models`: 可训练的模型，例如 `dit`、`vae`、`text_encoder`。\n        * `--find_unused_parameters`: DDP 训练中是否存在未使用的参数，少数模型包含不参与梯度计算的冗余参数，需开启这一设置避免在多 GPU 训练中报错。\n        * `--weight_decay`：权重衰减大小，详见 
[torch.optim.AdamW](https://docs.pytorch.org/docs/stable/generated/torch.optim.AdamW.html)。\n        * `--task`: 训练任务，默认为 `sft`，部分模型支持更多训练模式，请参考每个特定模型的文档。\n    * 输出配置\n        * `--output_path`: 模型保存路径。\n        * `--remove_prefix_in_ckpt`: 在模型文件的 state dict 中移除前缀。\n        * `--save_steps`: 保存模型的训练步数间隔，若此参数留空，则每个 epoch 保存一次。\n    * LoRA 配置\n        * `--lora_base_model`: LoRA 添加到哪个模型上。\n        * `--lora_target_modules`: LoRA 添加到哪些层上。\n        * `--lora_rank`: LoRA 的秩（Rank）。\n        * `--lora_checkpoint`: LoRA 检查点的路径。如果提供此路径，LoRA 将从此检查点加载。\n        * `--preset_lora_path`: 预置 LoRA 检查点路径，如果提供此路径，这一 LoRA 将会以融入基础模型的形式加载。此参数用于 LoRA 差分训练。\n        * `--preset_lora_model`: 预置 LoRA 融入的模型，例如 `dit`。\n    * 梯度配置\n        * `--use_gradient_checkpointing`: 是否启用 gradient checkpointing。\n        * `--use_gradient_checkpointing_offload`: 是否将 gradient checkpointing 卸载到内存中。\n        * `--gradient_accumulation_steps`: 梯度累积步数。\n    * 视频宽高配置\n        * `--height`: 视频的高度。将 `height` 和 `width` 留空以启用动态分辨率。\n        * `--width`: 视频的宽度。将 `height` 和 `width` 留空以启用动态分辨率。\n        * `--max_pixels`: 视频帧的最大像素面积，当启用动态分辨率时，分辨率大于这个数值的视频帧都会被缩小，分辨率小于这个数值的视频帧保持不变。\n        * `--num_frames`: 视频的帧数。\n* Wan 系列专有参数\n    * `--tokenizer_path`: tokenizer 的路径，适用于文生视频模型，留空则自动从远程下载。\n    * `--audio_processor_path`: 音频处理器的路径，适用于语音到视频模型，留空则自动从远程下载。\n\n我们构建了一个样例视频数据集，以方便您进行测试，通过以下命令可以下载这个数据集：\n\n```shell\nmodelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --local_dir ./data/diffsynth_example_dataset\n```\n\n我们为每个模型编写了推荐的训练脚本，请参考前文\"模型总览\"中的表格。关于如何编写模型训练脚本，请参考[模型训练](../Pipeline_Usage/Model_Training.md)；更多高阶训练算法，请参考[训练框架详解](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/zh/Training/)。\n"
  },
  {
    "path": "docs/zh/Model_Details/Z-Image.md",
    "content": "# Z-Image\n\nZ-Image 是由阿里巴巴通义实验室多模态交互团队训练并开源的图像生成模型。\n\n## 安装\n\n在使用本项目进行模型推理和训练前，请先安装 DiffSynth-Studio。\n\n```shell\ngit clone https://github.com/modelscope/DiffSynth-Studio.git\ncd DiffSynth-Studio\npip install -e .\n```\n\n更多关于安装的信息，请参考[安装依赖](../Pipeline_Usage/Setup.md)。\n\n## 快速开始\n\n运行以下代码可以快速加载 [Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo) 模型并进行推理。FP8 精度量化会导致明显的图像质量劣化，因此不建议在 Z-Image Turbo 模型上开启任何量化，仅建议开启 CPU Offload，最低 8G 显存即可运行。\n\n```python\nfrom diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig\nimport torch\n\nvram_config = {\n    \"offload_dtype\": torch.bfloat16,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = ZImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"transformer/*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"text_encoder/*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"tokenizer/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\nprompt = \"Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. 
Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights.\"\nimage = pipe(prompt=prompt, seed=42, rand_device=\"cuda\")\nimage.save(\"image.jpg\")\n```\n\n## 模型总览\n\n|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|\n|-|-|-|-|-|-|-|\n|[Tongyi-MAI/Z-Image](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_inference/Z-Image.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_inference_low_vram/Z-Image.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/full/Z-Image.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/validate_full/Z-Image.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/lora/Z-Image.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/validate_lora/Z-Image.py)|\n|[DiffSynth-Studio/Z-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-i2L)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_inference/Z-Image-i2L.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_inference_low_vram/Z-Image-i2L.py)|-|-|-|-|\n|[Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_inference/Z-Image-Turbo.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_inference_low_vram/Z-Image-Turbo.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/full/Z-Image-Turbo.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/validate_full/Z-Image-Turbo.py)|[code](https://github.com/modelscop
e/DiffSynth-Studio/blob/main/examples/z_image/model_training/lora/Z-Image-Turbo.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/validate_lora/Z-Image-Turbo.py)|\n|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|\n|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](https://github.com/modelsco
pe/DiffSynth-Studio/blob/main/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|\n|[PAI/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|\n\n特殊训练脚本：\n\n* 差分 LoRA 训练：[doc](../Training/Differential_LoRA.md)、[code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/z_image/model_training/special/differential_training/)\n* 轨迹模仿蒸馏训练（实验性功能）：[code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/z_image/model_training/special/trajectory_imitation/)\n\n## 模型推理\n\n模型通过 `ZImagePipeline.from_pretrained` 加载，详见[加载模型](../Pipeline_Usage/Model_Inference.md#加载模型)。\n\n`ZImagePipeline` 推理的输入参数包括：\n\n* `prompt`: 提示词，描述画面中出现的内容。\n* `negative_prompt`: 负向提示词，描述画面中不应该出现的内容，默认值为 `\"\"`。\n* `cfg_scale`: Classifier-free guidance 的参数，默认值为 1。\n* `input_image`: 输入图像，用于图生图，该参数与 
`denoising_strength` 配合使用。\n* `denoising_strength`: 去噪强度，范围是 0～1，默认值为 1，当数值接近 0 时，生成图像与输入图像相似；当数值接近 1 时，生成图像与输入图像相差更大。在不输入 `input_image` 参数时，请不要将其设置为非 1 的数值。\n* `height`: 图像高度，需保证高度为 16 的倍数。\n* `width`: 图像宽度，需保证宽度为 16 的倍数。\n* `seed`: 随机种子。默认为 `None`，即完全随机。\n* `rand_device`: 生成随机高斯噪声矩阵的计算设备，默认为 `\"cpu\"`。当设置为 `cuda` 时，在不同 GPU 上会导致不同的生成结果。\n* `num_inference_steps`: 推理步数，默认值为 8。\n* `controlnet_inputs`: ControlNet 模型的输入。\n* `edit_image`: 编辑模型的待编辑图像，支持多张图像。\n* `positive_only_lora`: 仅在正向提示词中使用的 LoRA 权重。\n\n如果显存不足，请开启[显存管理](../Pipeline_Usage/VRAM_management.md)，我们在示例代码中提供了每个模型推荐的低显存配置，详见前文\"模型总览\"中的表格。\n\n## 模型训练\n\nZ-Image 系列模型统一通过 [`examples/z_image/model_training/train.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/train.py) 进行训练，脚本的参数包括：\n\n* 通用训练参数\n    * 数据集基础配置\n        * `--dataset_base_path`: 数据集的根目录。\n        * `--dataset_metadata_path`: 数据集的元数据文件路径。\n        * `--dataset_repeat`: 每个 epoch 中数据集重复的次数。\n        * `--dataset_num_workers`: 每个 DataLoader 的进程数量。\n        * `--data_file_keys`: 元数据中需要加载的字段名称，通常是图像或视频文件的路径，以 `,` 分隔。\n    * 模型加载配置\n        * `--model_paths`: 要加载的模型路径。JSON 格式。\n        * `--model_id_with_origin_paths`: 带原始路径的模型 ID，例如 `\"Tongyi-MAI/Z-Image-Turbo:transformer/*.safetensors\"`。用逗号分隔。\n        * `--extra_inputs`: 模型 Pipeline 所需的额外输入参数，例如训练图像编辑模型时需要额外参数，以 `,` 分隔。\n        * `--fp8_models`：以 FP8 格式加载的模型，格式与 `--model_paths` 或 `--model_id_with_origin_paths` 一致，目前仅支持参数不被梯度更新的模型（不需要梯度回传，或梯度仅更新其 LoRA）。\n    * 训练基础配置\n        * `--learning_rate`: 学习率。\n        * `--num_epochs`: 轮数（Epoch）。\n        * `--trainable_models`: 可训练的模型，例如 `dit`、`vae`、`text_encoder`。\n        * `--find_unused_parameters`: DDP 训练中是否存在未使用的参数，少数模型包含不参与梯度计算的冗余参数，需开启这一设置避免在多 GPU 训练中报错。\n        * `--weight_decay`：权重衰减大小，详见 [torch.optim.AdamW](https://docs.pytorch.org/docs/stable/generated/torch.optim.AdamW.html)。\n        * `--task`: 训练任务，默认为 `sft`，部分模型支持更多训练模式，请参考每个特定模型的文档。\n    * 输出配置\n        * `--output_path`: 模型保存路径。\n        * 
`--remove_prefix_in_ckpt`: 在模型文件的 state dict 中移除前缀。\n        * `--save_steps`: 保存模型的训练步数间隔，若此参数留空，则每个 epoch 保存一次。\n    * LoRA 配置\n        * `--lora_base_model`: LoRA 添加到哪个模型上。\n        * `--lora_target_modules`: LoRA 添加到哪些层上。\n        * `--lora_rank`: LoRA 的秩（Rank）。\n        * `--lora_checkpoint`: LoRA 检查点的路径。如果提供此路径，LoRA 将从此检查点加载。\n        * `--preset_lora_path`: 预置 LoRA 检查点路径，如果提供此路径，这一 LoRA 将会以融入基础模型的形式加载。此参数用于 LoRA 差分训练。\n        * `--preset_lora_model`: 预置 LoRA 融入的模型，例如 `dit`。\n    * 梯度配置\n        * `--use_gradient_checkpointing`: 是否启用 gradient checkpointing。\n        * `--use_gradient_checkpointing_offload`: 是否将 gradient checkpointing 卸载到内存中。\n        * `--gradient_accumulation_steps`: 梯度累积步数。\n    * 图像宽高配置（适用于图像生成模型和视频生成模型）\n        * `--height`: 图像或视频的高度。将 `height` 和 `width` 留空以启用动态分辨率。\n        * `--width`: 图像或视频的宽度。将 `height` 和 `width` 留空以启用动态分辨率。\n        * `--max_pixels`: 图像或视频帧的最大像素面积，当启用动态分辨率时，分辨率大于这个数值的图片都会被缩小，分辨率小于这个数值的图片保持不变。\n* Z-Image 专有参数\n    * `--tokenizer_path`: tokenizer 的路径，适用于文生图模型，留空则自动从远程下载。\n\n我们构建了一个样例图像数据集，以方便您进行测试，通过以下命令可以下载这个数据集：\n\n```shell\nmodelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --local_dir ./data/diffsynth_example_dataset\n```\n\n我们为每个模型编写了推荐的训练脚本，请参考前文\"模型总览\"中的表格。关于如何编写模型训练脚本，请参考[模型训练](../Pipeline_Usage/Model_Training.md)；更多高阶训练算法，请参考[训练框架详解](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/zh/Training/)。\n\n训练技巧：\n\n* [Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo) 是一个蒸馏加速的模型，因此直接训练将会迅速让模型失去加速能力，以“加速配置”（`num_inference_steps=8`，`cfg_scale=1`）推理的效果变差，以“无加速配置”（`num_inference_steps=30`，`cfg_scale=2`）推理的效果变好。可采用以下方案训练和推理：\n    * 标准 SFT 训练（[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/lora/Z-Image-Turbo.sh)） + 无加速配置推理\n    * 差分 LoRA 训练（[code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/z_image/model_training/special/differential_training/)） + 加速配置推理\n        * 差分 LoRA 
训练中需加载一个额外的 LoRA，例如 [ostris/zimage_turbo_training_adapter](https://www.modelscope.cn/models/ostris/zimage_turbo_training_adapter)\n    * 标准 SFT 训练（[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/lora/Z-Image-Turbo.sh)）+ 轨迹模仿蒸馏训练（[code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/z_image/model_training/special/trajectory_imitation/)）+ 加速配置推理\n    * 标准 SFT 训练（[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/lora/Z-Image-Turbo.sh)）+ 推理时加载蒸馏加速 LoRA（[model](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-Turbo-DistillPatch)） + 加速配置推理\n"
  },
  {
    "path": "docs/zh/Pipeline_Usage/Environment_Variables.md",
    "content": "# 环境变量\n\n`DiffSynth-Studio` 可通过环境变量控制一些设置。\n\n在 `Python` 代码中，可以使用 `os.environ` 设置环境变量。请注意，环境变量需在 `import diffsynth` 前设置。\n\n```python\nimport os\nos.environ[\"DIFFSYNTH_MODEL_BASE_PATH\"] = \"./path_to_my_models\"\nimport diffsynth\n```\n\n在 Linux 操作系统上，也可在命令行临时设置环境变量：\n\n```shell\nDIFFSYNTH_MODEL_BASE_PATH=\"./path_to_my_models\" python xxx.py\n```\n\n以下是 `DiffSynth-Studio` 所支持的环境变量。\n\n## `DIFFSYNTH_SKIP_DOWNLOAD`\n\n是否跳过模型下载。可设置为 `True`、`true`、`False`、`false`，若 `ModelConfig` 中没有设置 `skip_download`，则会根据这一环境变量决定是否跳过模型下载。\n\n## `DIFFSYNTH_MODEL_BASE_PATH`\n\n模型下载根目录。可设置为任意本地路径，若 `ModelConfig` 中没有设置 `local_model_path`，则会将模型文件下载到这一环境变量指向的路径。若两者都未设置，则会将模型文件下载到 `./models`。\n\n## `DIFFSYNTH_ATTENTION_IMPLEMENTATION`\n\n注意力机制实现的方式，可以设置为 `flash_attention_3`、`flash_attention_2`、`sage_attention`、`xformers`、`torch`。详见 [`./core/attention.md`](../API_Reference/core/attention.md).\n\n## `DIFFSYNTH_DISK_MAP_BUFFER_SIZE`\n\n硬盘直连中的 Buffer 大小，默认是 1B（1000000000），数值越大，占用内存越大，速度越快。\n\n## `DIFFSYNTH_DOWNLOAD_SOURCE`\n\n远程模型下载源，可设置为 `modelscope` 或 `huggingface`，控制模型下载的来源，默认值为 `modelscope`。\n"
  },
  {
    "path": "docs/zh/Pipeline_Usage/GPU_support.md",
    "content": "# GPU/NPU 支持\n\n`DiffSynth-Studio` 支持多种 GPU/NPU，本文介绍如何在这些设备上运行模型推理和训练。\n\n在开始前，请参考[安装依赖](../Pipeline_Usage/Setup.md)安装好 GPU/NPU 相关的依赖包。\n\n## NVIDIA GPU\n\n本项目提供的所有样例代码默认支持 NVIDIA GPU，无需额外修改。\n\n## AMD GPU\n\nAMD 提供了基于 ROCm 的 torch 包，所以大多数模型无需修改代码即可运行，少数模型由于依赖特定的 cuda 指令无法运行。\n\n## Ascend NPU\n### 推理\n使用 Ascend NPU 时，需把代码中的 `\"cuda\"` 改为 `\"npu\"`。\n\n例如，Wan2.1-T2V-1.3B 的推理代码：\n\n```diff\nimport torch\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom diffsynth.core.device.npu_compatible_device import get_device_name\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.bfloat16,\n-   \"preparing_device\": \"cuda\",\n+   \"preparing_device\": \"npu\",\n    \"computation_dtype\": torch.bfloat16,\n-   \"computation_device\": \"cuda\",\n+   \"computation_device\": \"npu\",\n}\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n-   device=\"cuda\",\n+   device=\"npu\",\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\", **vram_config),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"Wan2.1_VAE.pth\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n-   vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 2,\n+   vram_limit=torch.npu.mem_get_info(get_device_name())[1] / (1024 ** 3) - 2,\n)\n\nvideo = pipe(\n    prompt=\"纪实摄影风格画面，一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄，两只耳朵立起，神情专注而欢快。阳光洒在它身上，使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地，偶尔点缀着几朵野花，远处隐约可见蓝天和几片白云。透视感鲜明，捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。\",\n  
  negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    seed=0, tiled=True,\n)\nsave_video(video, \"video.mp4\", fps=15, quality=5)\n```\n\n#### USP(Unified Sequence Parallel)\n如果想要在NPU上使用该特性，请通过如下方式安装额外的第三方库：\n```shell\npip install git+https://github.com/feifeibear/long-context-attention.git\npip install git+https://github.com/xdit-project/xDiT.git\n```\n\n### 训练\n当前已为每类模型添加NPU的启动脚本样例，脚本存放在`examples/xxx/special/npu_training`目录下，例如 `examples/wanvideo/model_training/special/npu_training/Wan2.2-T2V-A14B-NPU.sh`。\n\n在NPU训练脚本中，添加了可以优化性能的NPU特有环境变量，并针对特定模型开启了相关参数。\n\n#### 环境变量\n```shell\nexport PYTORCH_NPU_ALLOC_CONF=expandable_segments:True\n```\n`expandable_segments:<value>`: 使能内存池扩展段功能，即虚拟内存特征。\n\n```shell\nexport CPU_AFFINITY_CONF=1\n```\n设置0或未设置: 表示不启用绑核功能\n\n1: 表示开启粗粒度绑核\n\n2: 表示开启细粒度绑核\n\n#### 特定模型需要开启的参数\n| 模型        | 参数 | 备注                |\n|-----------|------|-------------------|\n| Wan 14B系列 | --initialize_model_on_cpu | 14B模型需要在cpu上进行初始化 |\n| Qwen-Image系列 | --initialize_model_on_cpu | 模型需要在cpu上进行初始化 |\n| Z-Image 系列 | --enable_npu_patch | 使用NPU融合算子来替换Z-image模型中的对应算子以提升模型在NPU上的性能 |"
  },
  {
    "path": "docs/zh/Pipeline_Usage/Model_Inference.md",
    "content": "# 模型推理\n\n本文档以 Qwen-Image 模型为例，介绍如何使用 `DiffSynth-Studio` 进行模型推理。\n\n## 加载模型\n\n模型通过 `from_pretrained` 加载：\n\n```python\nfrom diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nimport torch\n\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n)\n```\n\n其中 `torch_dtype` 和 `device` 是计算精度和计算设备（不是模型的精度和设备）。`model_configs` 可通过多种方式配置模型路径，关于本项目内部是如何加载模型的，请参考 [`diffsynth.core.loader`](../API_Reference/core/loader.md)。\n\n<details>\n\n<summary>从远程下载模型并加载</summary>\n\n> `DiffSynth-Studio` 默认从[魔搭社区](https://www.modelscope.cn/)下载并加载模型，需填写 `model_id` 和 `origin_file_pattern`，例如\n> \n> ```python\n> ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n> ```\n> \n> 模型文件默认下载到 `./models` 路径，该路径可通过[环境变量 DIFFSYNTH_MODEL_BASE_PATH](../Pipeline_Usage/Environment_Variables.md#diffsynth_model_base_path) 修改。\n\n</details>\n\n<details>\n\n<summary>从本地文件路径加载模型</summary>\n\n> 填写 `path`，例如\n> \n> ```python\n> ModelConfig(path=\"models/xxx.safetensors\")\n> ```\n> \n> 对于从多个文件加载的模型，使用列表即可，例如\n> \n> ```python\n> ModelConfig(path=[\n>     \"models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors\",\n>     \"models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors\",\n>     \"models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors\",\n>     \"models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors\",\n> ])\n> 
```\n\n</details>\n\n默认情况下，即使模型已经下载完毕，程序仍会向远程查询是否有遗漏文件。如果要完全关闭远程请求，请将[环境变量 DIFFSYNTH_SKIP_DOWNLOAD](../Pipeline_Usage/Environment_Variables.md#diffsynth_skip_download) 设置为 `True`（如下例所示，在导入 `diffsynth` 之前设置）：\n\n```python\nimport os\nos.environ[\"DIFFSYNTH_SKIP_DOWNLOAD\"] = \"True\"\nimport diffsynth\n```\n\n如需从 [HuggingFace](https://huggingface.co/) 下载模型，请将[环境变量 DIFFSYNTH_DOWNLOAD_SOURCE](../Pipeline_Usage/Environment_Variables.md#diffsynth_download_source) 设置为 `huggingface`：\n\n```python\nimport os\nos.environ[\"DIFFSYNTH_DOWNLOAD_SOURCE\"] = \"huggingface\"\nimport diffsynth\n```\n\n## 启动推理\n\n输入提示词，即可启动推理过程，生成一张图片。\n\n```python\nfrom diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nimport torch\n\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n)\nprompt = \"精致肖像，水下少女，蓝裙飘逸，发丝轻扬，光影透澈，气泡环绕，面容恬静，细节精致，梦幻唯美。\"\nimage = pipe(prompt, seed=0, num_inference_steps=40)\nimage.save(\"image.jpg\")\n```\n\n每个模型 `Pipeline` 的输入参数不同，请参考各模型的文档。\n\n如果模型参数量太大，导致显存不足，请开启[显存管理](../Pipeline_Usage/VRAM_management.md)。\n\n## 加载 LoRA\n\nLoRA 是一种轻量化的模型训练方式，仅需训练少量参数即可扩展模型的能力。DiffSynth-Studio 的 LoRA 加载有两种方式：冷加载和热加载。\n\n* 冷加载：当基础模型未开启[显存管理](../Pipeline_Usage/VRAM_management.md)时，LoRA 会融合进基础模型权重，此时推理速度没有变化，LoRA 加载后无法卸载。\n\n```python\nfrom diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nimport torch\n\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        
ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n)\nlora = ModelConfig(model_id=\"DiffSynth-Studio/Qwen-Image-LoRA-ArtAug-v1\", origin_file_pattern=\"model.safetensors\")\npipe.load_lora(pipe.dit, lora, alpha=1)\nprompt = \"精致肖像，水下少女，蓝裙飘逸，发丝轻扬，光影透澈，气泡环绕，面容恬静，细节精致，梦幻唯美。\"\nimage = pipe(prompt, seed=0, num_inference_steps=40)\nimage.save(\"image.jpg\")\n```\n\n* 热加载：当基础模型开启[显存管理](../Pipeline_Usage/VRAM_management.md)时，LoRA 不会融合进基础模型权重，此时推理速度会变慢，LoRA 加载后可通过 `pipe.clear_lora()` 卸载。\n\n```python\nfrom diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nimport torch\n\nvram_config = {\n    \"offload_dtype\": torch.bfloat16,\n    \"offload_device\": \"cuda\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cuda\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n)\nlora = ModelConfig(model_id=\"DiffSynth-Studio/Qwen-Image-LoRA-ArtAug-v1\", 
origin_file_pattern=\"model.safetensors\")\npipe.load_lora(pipe.dit, lora, alpha=1)\nprompt = \"精致肖像，水下少女，蓝裙飘逸，发丝轻扬，光影透澈，气泡环绕，面容恬静，细节精致，梦幻唯美。\"\nimage = pipe(prompt, seed=0, num_inference_steps=40)\nimage.save(\"image.jpg\")\npipe.clear_lora()\n```\n"
  },
  {
    "path": "docs/zh/Pipeline_Usage/Model_Training.md",
    "content": "# 模型训练\n\n本文档介绍如何使用 `DiffSynth-Studio` 进行模型训练。\n\n## 脚本参数\n\n训练脚本通常包含以下参数：\n\n* 数据集基础配置\n    * `--dataset_base_path`: 数据集的根目录。\n    * `--dataset_metadata_path`: 数据集的元数据文件路径。\n    * `--dataset_repeat`: 每个 epoch 中数据集重复的次数。\n    * `--dataset_num_workers`: 每个 Dataloder 的进程数量。\n    * `--data_file_keys`: 元数据中需要加载的字段名称，通常是图像或视频文件的路径，以 `,` 分隔。\n* 模型加载配置\n    * `--model_paths`: 要加载的模型路径。JSON 格式。\n    * `--model_id_with_origin_paths`: 带原始路径的模型 ID，例如 `\"Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors\"`。用逗号分隔。\n    * `--extra_inputs`: 模型 Pipeline 所需的额外输入参数，例如训练图像编辑模型 Qwen-Image-Edit 时需要额外参数 `edit_image`，以 `,` 分隔。\n    * `--fp8_models`：以 FP8 格式加载的模型，格式与 `--model_paths` 或 `--model_id_with_origin_paths` 一致，目前仅支持参数不被梯度更新的模型（不需要梯度回传，或梯度仅更新其 LoRA）。\n* 训练基础配置\n    * `--learning_rate`: 学习率。\n    * `--num_epochs`: 轮数（Epoch）。\n    * `--trainable_models`: 可训练的模型，例如 `dit`、`vae`、`text_encoder`。\n    * `--find_unused_parameters`: DDP 训练中是否存在未使用的参数，少数模型包含不参与梯度计算的冗余参数，需开启这一设置避免在多 GPU 训练中报错。\n    * `--weight_decay`：权重衰减大小，详见 [torch.optim.AdamW](https://docs.pytorch.org/docs/stable/generated/torch.optim.AdamW.html)。\n    * `--task`: 训练任务，默认为 `sft`，部分模型支持更多训练模式，请参考每个特定模型的文档。\n* 输出配置\n    * `--output_path`: 模型保存路径。\n    * `--remove_prefix_in_ckpt`: 在模型文件的 state dict 中移除前缀。\n    * `--save_steps`: 保存模型的训练步数间隔，若此参数留空，则每个 epoch 保存一次。\n* LoRA 配置\n    * `--lora_base_model`: LoRA 添加到哪个模型上。\n    * `--lora_target_modules`: LoRA 添加到哪些层上。\n    * `--lora_rank`: LoRA 的秩（Rank）。\n    * `--lora_checkpoint`: LoRA 检查点的路径。如果提供此路径，LoRA 将从此检查点加载。\n    * `--preset_lora_path`: 预置 LoRA 检查点路径，如果提供此路径，这一 LoRA 将会以融入基础模型的形式加载。此参数用于 LoRA 差分训练。\n    * `--preset_lora_model`: 预置 LoRA 融入的模型，例如 `dit`。\n* 梯度配置\n    * `--use_gradient_checkpointing`: 是否启用 gradient checkpointing。\n    * `--use_gradient_checkpointing_offload`: 是否将 gradient checkpointing 卸载到内存中。\n    * `--gradient_accumulation_steps`: 梯度累积步数。\n* 图像宽高配置（适用于图像生成模型和视频生成模型）\n    * `--height`: 图像或视频的高度。将 `height` 和 `width` 留空以启用动态分辨率。\n    
* `--width`: 图像或视频的宽度。将 `height` 和 `width` 留空以启用动态分辨率。\n    * `--max_pixels`: 图像或视频帧的最大像素面积，当启用动态分辨率时，分辨率大于这个数值的图片都会被缩小，分辨率小于这个数值的图片保持不变。\n\n部分模型的训练脚本还包含额外的参数，详见各模型的文档。\n\n## 准备数据集\n\n`DiffSynth-Studio` 采用通用数据集格式，数据集包含一系列数据文件（图像、视频等），以及标注元数据的文件，我们建议您这样组织数据集文件：\n\n```\ndata/example_image_dataset/\n├── metadata.csv\n├── image_1.jpg\n└── image_2.jpg\n```\n\n其中 `image_1.jpg`、`image_2.jpg` 为训练用图像数据，`metadata.csv` 为元数据列表，例如\n\n```\nimage,prompt\nimage_1.jpg,\"a dog\"\nimage_2.jpg,\"a cat\"\n```\n\n我们构建了样例数据集，以方便您进行测试。如需了解通用数据集架构是如何实现的，请参考 [`diffsynth.core.data`](../API_Reference/core/data.md)。\n\n<details>\n\n<summary>样例数据集</summary>\n\n> ```shell\n> modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --local_dir ./data/diffsynth_example_dataset\n> ```\n> \n> 适用于 Qwen-Image、FLUX 等图像生成模型的训练。\n\n</details>\n\n## 加载模型\n\n类似于[推理时的模型加载](../Pipeline_Usage/Model_Inference.md#加载模型)，我们支持两种方式配置模型路径（从远程下载模型，或从本地文件路径加载模型），这两种方式可以混用。\n\n<details>\n\n<summary>从远程下载模型并加载</summary>\n\n> 如果在推理时我们通过以下设置加载模型\n> \n> ```python\n> model_configs=[\n>     ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n>     ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n>     ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n> ]\n> ```\n> \n> 那么在训练时，填入以下参数即可加载对应的模型。\n> \n> ```shell\n> --model_id_with_origin_paths \"Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors\"\n> ```\n> \n> 模型文件默认下载到 `./models` 路径，该路径可通过[环境变量 DIFFSYNTH_MODEL_BASE_PATH](../Pipeline_Usage/Environment_Variables.md#diffsynth_model_base_path) 修改。\n> \n> 默认情况下，即使模型已经下载完毕，程序仍会向远程查询是否有遗漏文件，如果要完全关闭远程请求，请将[环境变量 DIFFSYNTH_SKIP_DOWNLOAD](../Pipeline_Usage/Environment_Variables.md#diffsynth_skip_download) 设置为 
`True`。\n\n</details>\n\n<details>\n\n<summary>从本地文件路径加载模型</summary>\n\n> 如果从本地文件加载模型，例如推理时\n> \n> ```python\n> model_configs=[\n>     ModelConfig([\n>         \"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00001-of-00009.safetensors\",\n>         \"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00002-of-00009.safetensors\",\n>         \"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00003-of-00009.safetensors\",\n>         \"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00004-of-00009.safetensors\",\n>         \"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00005-of-00009.safetensors\",\n>         \"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00006-of-00009.safetensors\",\n>         \"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00007-of-00009.safetensors\",\n>         \"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00008-of-00009.safetensors\",\n>         \"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00009-of-00009.safetensors\"\n>     ]),\n>     ModelConfig([\n>         \"models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors\",\n>         \"models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors\",\n>         \"models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors\",\n>         \"models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors\"\n>     ]),\n>     ModelConfig(\"models/Qwen/Qwen-Image/vae/diffusion_pytorch_model.safetensors\")\n> ]\n> ```\n> \n> 那么训练时需设置为\n> \n> ```shell\n> --model_paths '[\n>     [\n>         \"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00001-of-00009.safetensors\",\n>         \"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00002-of-00009.safetensors\",\n>         \"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00003-of-00009.safetensors\",\n>         
\"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00004-of-00009.safetensors\",\n>         \"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00005-of-00009.safetensors\",\n>         \"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00006-of-00009.safetensors\",\n>         \"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00007-of-00009.safetensors\",\n>         \"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00008-of-00009.safetensors\",\n>         \"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00009-of-00009.safetensors\"\n>     ],\n>     [\n>         \"models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors\",\n>         \"models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors\",\n>         \"models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors\",\n>         \"models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors\"\n>     ],\n>     \"models/Qwen/Qwen-Image/vae/diffusion_pytorch_model.safetensors\"\n> ]' \\\n> ```\n> \n> 请注意，`--model_paths` 是 json 格式，其中不能出现多余的 `,`，否则无法被正常解析。\n\n</details>\n\n## 设置可训练模块\n\n训练框架支持任意模型的训练，以 Qwen-Image 为例，若全量训练其中的 DiT 模型，则需设置为\n\n```shell\n--trainable_models \"dit\"\n```\n\n若训练 DiT 模型的 LoRA，则需设置\n\n```shell\n--lora_base_model dit --lora_target_modules \"to_q,to_k,to_v\" --lora_rank 32\n```\n\n我们希望给技术探索留下足够的发挥空间，因此框架支持同时训练任意多个模块，例如同时训练 text encoder、controlnet，以及 DiT 的 LoRA：\n\n```shell\n--trainable_models \"text_encoder,controlnet\" --lora_base_model dit --lora_target_modules \"to_q,to_k,to_v\" --lora_rank 32\n```\n\n此外，由于训练脚本中加载了多个模块（text encoder、dit、vae 等），保存模型文件时需要移除前缀，例如在全量训练 DiT 部分或者训练 DiT 部分的 LoRA 模型时，请设置 `--remove_prefix_in_ckpt pipe.dit.`。如果多个模块同时训练，则需开发者在训练完成后自行编写代码拆分模型文件中的 state dict。\n\n## 启动训练程序\n\n训练框架基于 [`accelerate`](https://huggingface.co/docs/accelerate/index) 构建，训练命令按照如下格式编写：\n\n```shell\naccelerate launch xxx/train.py \\\n  --xxx yyy \\\n  --xxxx 
yyyy\n```\n\n我们为每个模型编写了预置的训练脚本，详见各模型的文档。\n\n默认情况下，`accelerate` 会按照 `~/.cache/huggingface/accelerate/default_config.yaml` 的配置进行训练，使用 `accelerate config` 可在终端交互式地配置，包括多 GPU 训练、[`DeepSpeed`](https://www.deepspeed.ai/) 等。\n\n我们为部分模型提供了推荐的 `accelerate` 配置文件，可通过 `--config_file` 设置，例如 Qwen-Image 模型的全量训练：\n\n```shell\naccelerate launch --config_file examples/qwen_image/model_training/full/accelerate_config_zero2offload.yaml examples/qwen_image/model_training/train.py \\\n  --dataset_base_path data/example_image_dataset \\\n  --dataset_metadata_path data/example_image_dataset/metadata.csv \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 50 \\\n  --model_id_with_origin_paths \"Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 2 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Qwen-Image_full\" \\\n  --trainable_models \"dit\" \\\n  --use_gradient_checkpointing \\\n  --find_unused_parameters\n```\n\n## 训练注意事项\n\n* 数据集的元数据除 `csv` 格式外，还支持 `json`、`jsonl` 格式，关于如何选择最佳的元数据格式，请参考[元数据](../API_Reference/core/data.md#元数据)。\n* 通常训练效果与训练步数强相关，与 epoch 数量弱相关，因此我们更推荐使用参数 `--save_steps` 按训练步数间隔来保存模型文件。\n* 当数据量 * `dataset_repeat` 超过 $10^9$ 时，我们观测到数据集的加载速度明显变慢，这似乎是 `PyTorch` 的 bug，我们尚不确定新版本的 `PyTorch` 是否已经修复了这一问题。\n* 学习率 `--learning_rate` 在 LoRA 训练中建议设置为 `1e-4`，在全量训练中建议设置为 `1e-5`。\n* 训练框架不支持 batch size > 1，原因较为复杂，详见 [Q&A: 为什么训练框架不支持 batch size > 1？](../QA.md#为什么训练框架不支持-batch-size--1)\n* 少数模型包含冗余参数，例如 Qwen-Image 的 DiT 部分最后一层的文本编码部分，在训练这些模型时，需设置 `--find_unused_parameters` 避免在多 GPU 训练中报错。出于对开源社区模型兼容性的考虑，我们不打算删除这些冗余参数。\n* Diffusion 模型的损失函数值与实际效果的关系不大，因此我们在训练过程中不会记录损失函数值。我们建议把 `--num_epochs` 设置为足够大的数值，边训边测，直至效果收敛后手动关闭训练程序。\n* 除非 GPU 显存充足，否则通常应开启 `--use_gradient_checkpointing`；`--use_gradient_checkpointing_offload` 则按需开启，详见 [`diffsynth.core.gradient`](../API_Reference/core/gradient.md)。\n\n## 
低显存训练\n如果想在低显存显卡上完成 LoRA 模型训练，可以同时采用 [两阶段拆分训练](../Training/Split_Training.md) 和 `deepspeed_zero3_offload` 训练。 首先，将前处理过程拆分到第一阶段，将计算结果存储到硬盘中。其次，在第二阶段从硬盘中读取这些结果并进行去噪模型的训练，训练通过采用 `deepspeed_zero3_offload`，将训练参数和优化器状态 offload 到 cpu 或者 disk 上。我们为部分模型提供了样例，主要是通过 `--config_file` 指定 `deepspeed` 配置。\n\n需要注意的是，`deepspeed_zero3_offload` 模式与 `pytorch` 原生的梯度检查点机制不兼容，我们为此对 `deepspeed` 的`checkpointing` 接口做了适配。用户需要在 `deepspeed` 配置中填写 `activation_checkpointing` 字段以启用梯度检查点。\n\n以下为 Qwen-Image 模型的低显存模型训练脚本：\n```shell\naccelerate launch examples/qwen_image/model_training/train.py \\\n  --dataset_base_path data/example_image_dataset \\\n  --dataset_metadata_path data/example_image_dataset/metadata.csv \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 1 \\\n  --model_id_with_origin_paths \"Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Qwen-Image_lora-splited-cache\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1\" \\\n  --lora_rank 32 \\\n  --task \"sft:data_process\" \\\n  --use_gradient_checkpointing \\\n  --dataset_num_workers 8 \\\n  --find_unused_parameters\n\naccelerate launch --config_file examples/qwen_image/model_training/special/low_vram_training/deepspeed_zero3_cpuoffload.yaml examples/qwen_image/model_training/train.py \\\n  --dataset_base_path \"./models/train/Qwen-Image_lora-splited-cache\" \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 50 \\\n  --model_id_with_origin_paths \"Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Qwen-Image_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules 
\"to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1\" \\\n  --lora_rank 32 \\\n  --task \"sft:train\" \\\n  --use_gradient_checkpointing \\\n  --dataset_num_workers 8 \\\n  --find_unused_parameters \\\n  --initialize_model_on_cpu\n```\n\n其中，`accelerate` 和 `deepspeed` 的配置文件如下：\n\n```yaml\ncompute_environment: LOCAL_MACHINE\ndebug: true\ndeepspeed_config:\n  deepspeed_config_file: examples/qwen_image/model_training/special/low_vram_training/ds_z3_cpuoffload.json\n  zero3_init_flag: true\ndistributed_type: DEEPSPEED\ndowncast_bf16: 'no'\nenable_cpu_affinity: false\nmachine_rank: 0\nmain_training_function: main\nnum_machines: 1\nnum_processes: 1\nrdzv_backend: static\nsame_network: true\ntpu_env: []\ntpu_use_cluster: false\ntpu_use_sudo: false\nuse_cpu: false\n```\n\n```json\n{\n    \"fp16\": {\n        \"enabled\": \"auto\",\n        \"loss_scale\": 0,\n        \"loss_scale_window\": 1000,\n        \"initial_scale_power\": 16,\n        \"hysteresis\": 2,\n        \"min_loss_scale\": 1\n    },\n    \"bf16\": {\n        \"enabled\": \"auto\"\n    },\n    \"zero_optimization\": {\n        \"stage\": 3,\n        \"offload_optimizer\": {\n            \"device\": \"cpu\",\n            \"pin_memory\": true\n        },\n        \"offload_param\": {\n            \"device\": \"cpu\",\n            \"pin_memory\": true\n        },\n        \"overlap_comm\": false,\n        \"contiguous_gradients\": true,\n        \"sub_group_size\": 1e9,\n        \"reduce_bucket_size\": 5e7,\n        \"stage3_prefetch_bucket_size\": 5e7,\n        \"stage3_param_persistence_threshold\": 1e5,\n        \"stage3_max_live_parameters\": 1e8,\n        \"stage3_max_reuse_distance\": 1e8,\n        \"stage3_gather_16bit_weights_on_model_save\": true\n    },\n    \"activation_checkpointing\": {\n        \"partition_activations\": false,\n        \"cpu_checkpointing\": false,\n        \"contiguous_memory_optimization\": false\n    },\n    
\"gradient_accumulation_steps\": \"auto\",\n    \"gradient_clipping\": \"auto\",\n    \"train_batch_size\": \"auto\",\n    \"train_micro_batch_size_per_gpu\": \"auto\",\n    \"wall_clock_breakdown\": false\n}\n```"
  },
  {
    "path": "docs/zh/Pipeline_Usage/Setup.md",
    "content": "# 安装依赖\n\n从源码安装（推荐）：\n\n```\ngit clone https://github.com/modelscope/DiffSynth-Studio.git\ncd DiffSynth-Studio\npip install -e .\n```\n\n从 pypi 安装（存在版本更新延迟，如需使用最新功能，请从源码安装）\n\n```\npip install diffsynth\n```\n\n## GPU/NPU 支持\n\n* NVIDIA GPU\n\n按照以上方式安装即可。\n\n* AMD GPU\n\n需安装支持 ROCm 的 `torch` 包，以 ROCm 6.4（本文更新于 2025 年 12 月 15 日）、Linux 系统为例，请运行以下命令\n\n```shell\npip install torch torchvision --index-url https://download.pytorch.org/whl/rocm6.4\n```\n\n* Ascend NPU\n\n1. 通过官方文档安装[CANN](https://www.hiascend.com/document/detail/zh/canncommercial/83RC1/softwareinst/instg/instg_quick.html?Mode=PmIns&InstallType=local&OS=openEuler&Software=cannToolKit)\n\n2. 从源码安装\n   ```shell\n   git clone https://github.com/modelscope/DiffSynth-Studio.git\n   cd DiffSynth-Studio\n   # aarch64/ARM\n   pip install -e .[npu_aarch64] \n   # x86\n   pip install -e .[npu] --extra-index-url \"https://download.pytorch.org/whl/cpu\"\n\n使用 Ascend NPU 时，请将 Python 代码中的 `\"cuda\"` 改为 `\"npu\"`，详见[NPU 支持](../Pipeline_Usage/GPU_support.md#ascend-npu)。\n\n## 其他安装问题\n\n如果在安装过程中遇到问题，可能是由上游依赖包导致的，请参考这些包的文档：\n\n* [torch](https://pytorch.org/get-started/locally/)\n* [Ascend/pytorch](https://github.com/Ascend/pytorch)\n* [sentencepiece](https://github.com/google/sentencepiece)\n* [cmake](https://cmake.org)\n"
  },
  {
    "path": "docs/zh/Pipeline_Usage/VRAM_management.md",
    "content": "# 显存管理\n\n显存管理是 `DiffSynth-Studio` 的特色功能，能够让低显存的 GPU 能够运行参数量巨大的模型推理。本文档以 Qwen-Image 为例，介绍显存管理方案的使用。\n\n## 基础推理\n\n以下代码中没有启用任何显存管理，显存占用 56G，作为参考。\n\n```python\nfrom diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nimport torch\n\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n)\nprompt = \"精致肖像，水下少女，蓝裙飘逸，发丝轻扬，光影透澈，气泡环绕，面容恬静，细节精致，梦幻唯美。\"\nimage = pipe(prompt, seed=0, num_inference_steps=40)\nimage.save(\"image.jpg\")\n```\n\n## CPU Offload\n\n由于模型 `Pipeline` 包括多个组件，这些组件并非同时调用的，因此我们可以在某些组件不需要参与计算时将其移至内存，减少显存占用，以下代码可以实现这一逻辑，显存占用 40G。\n\n```python\nfrom diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nimport torch\n\nvram_config = {\n    \"offload_dtype\": torch.bfloat16,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cuda\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", 
origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n)\nprompt = \"精致肖像，水下少女，蓝裙飘逸，发丝轻扬，光影透澈，气泡环绕，面容恬静，细节精致，梦幻唯美。\"\nimage = pipe(prompt, seed=0, num_inference_steps=40)\nimage.save(\"image.jpg\")\n```\n\n## FP8 量化\n\n在 CPU Offload 的基础上，我们进一步启用 FP8 量化来减少显存需求，以下代码可以令模型参数以 FP8 精度存储在显存中，并在推理时临时转为 BF16 精度计算，显存占用 21G。但这种量化方案有微小的图像质量下降问题。\n\n```python\nfrom diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nimport torch\n\nvram_config = {\n    \"offload_dtype\": torch.float8_e4m3fn,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.float8_e4m3fn,\n    \"onload_device\": \"cuda\",\n    \"preparing_dtype\": torch.float8_e4m3fn,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n)\nprompt = \"精致肖像，水下少女，蓝裙飘逸，发丝轻扬，光影透澈，气泡环绕，面容恬静，细节精致，梦幻唯美。\"\nimage = pipe(prompt, seed=0, num_inference_steps=40)\nimage.save(\"image.jpg\")\n```\n\n> Q: 为什么要在推理时临时转为 BF16 精度，而不是以 FP8 精度计算？\n> \n> A: FP8 的原生计算仅在 Hopper 架构的 GPU（例如 H20）支持，且计算误差很大，我们目前暂不开放 FP8 精度计算。目前的 FP8 量化仅能减少显存占用，不会提高计算速度。\n\n## 动态显存管理\n\n在 CPU Offload 中，我们对模型组件进行控制，事实上，我们支持做到 Layer 级别的 Offload，将一个模型拆分为多个 Layer，令一部分常驻显存，令一部分存储在内存中按需移至显存计算。这一功能需要模型开发者针对每个模型提供详细的显存管理方案，相关配置在 
`diffsynth/configs/vram_management_module_maps.py` 中。\n\n通过在 `Pipeline` 中增加 `vram_limit` 参数，框架可以自动感知设备的剩余显存并决定如何拆分模型到显存和内存中。`vram_limit` 越小，占用显存越少，速度越慢。\n* `vram_limit=None` 时，即默认状态，框架认为显存无限，动态显存管理是不启用的\n* `vram_limit=10` 时，框架会在显存占用超过 10G 之后限制模型，将超出的部分移至内存中存储。\n* `vram_limit=0` 时，框架会尽全力减少显存占用，所有模型参数都存储在内存中，仅在必要时移至显存计算\n\n在显存不足以运行模型推理的情况下，框架会试图超出 `vram_limit` 的限制从而让模型推理运行下去，因此显存管理框架并不能总是保证占用的显存小于 `vram_limit`，我们建议将其设置为略小于实际可用显存的数值，例如 GPU 显存为 16G 时，设置为 `vram_limit=15.5`。`PyTorch` 中可用 `torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3)` 获取 GPU 的显存。\n\n```python\nfrom diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nimport torch\n\nvram_config = {\n    \"offload_dtype\": torch.float8_e4m3fn,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.float8_e4m3fn,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e4m3fn,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\nprompt = \"精致肖像，水下少女，蓝裙飘逸，发丝轻扬，光影透澈，气泡环绕，面容恬静，细节精致，梦幻唯美。\"\nimage = pipe(prompt, seed=0, num_inference_steps=40)\nimage.save(\"image.jpg\")\n```\n\n## Disk Offload\n\n在更为极端的情况下，当内存也不足以存储整个模型时，Disk Offload 功能可以让模型参数惰性加载，即，模型中的每个 Layer 仅在调用 forward 时才会从硬盘中读取相应的参数。启用这一功能时，我们建议使用高速的 SSD 硬盘。\n\nDisk 
Offload 是极为特殊的显存管理方案，只支持 `.safetensors` 格式文件，不支持 `.bin`、`.pth`、`.ckpt` 等二进制文件，不支持带 Tensor reshape 的 [state dict converter](../Developer_Guide/Integrating_Your_Model.md#step-2-模型文件格式转换)。\n\n```python\nfrom diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nimport torch\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": \"disk\",\n    \"onload_device\": \"disk\",\n    \"preparing_dtype\": torch.float8_e4m3fn,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n    vram_limit=10,\n)\nprompt = \"精致肖像，水下少女，蓝裙飘逸，发丝轻扬，光影透澈，气泡环绕，面容恬静，细节精致，梦幻唯美。\"\nimage = pipe(prompt, seed=0, num_inference_steps=40)\nimage.save(\"image.jpg\")\n```\n\n## 更多使用方式\n\n`vram_config` 中的信息可自行填写，例如不开 FP8 量化的 Disk Offload：\n\n```python\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": \"disk\",\n    \"onload_device\": \"disk\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\n```\n\n具体地，显存管理模块会将模型的 Layer 分为以下四种状态：\n\n* Offload：短期内不调用这个模型，这个状态由 `Pipeline` 控制切换\n* Onload：接下来随时要调用这个模型，这个状态由 `Pipeline` 控制切换\n* Preparing：Onload 和 Computation 
的中间状态，是在显存允许的前提下的暂存状态，这个状态由显存管理机制控制切换，当且仅当【vram_limit 设置为无限制】或【vram_limit 已设置且有空余显存】时会进入这一状态\n* Computation：模型正在计算过程中，这个状态由显存管理机制控制切换，仅在 `forward` 中临时进入\n\n如果你是模型开发者，希望自行控制某个模型的显存管理粒度，请参考[接入细粒度显存管理](../Developer_Guide/Enabling_VRAM_management.md)。\n\n## 最佳实践\n\n* 显存足够 -> 使用[基础推理](#基础推理)\n* 显存不足\n    * 内存足够 -> 使用[动态显存管理](#动态显存管理)\n    * 内存不足 -> 使用[Disk Offload](#disk-offload)\n"
  },
  {
    "path": "docs/zh/QA.md",
    "content": "# 常见问题\n\n## 为什么训练框架不支持 batch size > 1？\n\n* **更大的 batch size 已无法实现显著加速**：由于 flash attention 等加速技术已经充分提高了 GPU 的利用率，因此更大的 batch size 只会带来更大的显存占用，无法带来显著加速。在 Stable Diffusion 1.5 这类小模型上的经验已不再适用于最新的大模型。\n* **更大的 batch size 可以用其他方案实现**：多 GPU 训练和 Gradient Accumulation 都可以在数学意义上等价地实现更大的 batch size。\n* **更大的 batch size 与框架的通用性设计相悖**：我们希望构建通用的训练框架，大量模型无法适配更大的 batch size，例如不同长度的文本编码、不同分辨率的图像等，都是无法合并为更大的 batch 的。\n\n## 为什么不删除某些模型中的冗余参数？\n\n在部分模型中，模型存在冗余参数，例如 Qwen-Image 的 DiT 模型最后一层的文本部分，这部分参数不会参与任何计算，这是模型开发者留下的小 bug。直接将其设置为可训练时还会在多 GPU 训练中出现报错。\n\n为了与开源社区中其他模型保持兼容性，我们决定保留这些参数。这些冗余参数在多 GPU 训练中可以通过 `--find_unused_parameters` 参数避免报错。\n\n## 为什么 FP8 量化没有任何加速效果？\n\n原生 FP8 计算需要依赖 Hopper 架构的 GPU，同时在计算精度上有较大误差，目前仍然是不成熟的技术，因此本项目不支持原生 FP8 计算。\n\n显存管理中的 FP8 计算是指将模型参数以 FP8 精度存储在内存或显存中，在需要计算时临时转换为其他精度，因此仅能减少显存占用，没有加速效果。\n\n## 为什么训练框架不支持原生 FP8 精度训练？\n\n即使硬件条件允许，我们目前也没有任何支持原生 FP8 精度训练的规划。\n\n* 目前原生 FP8 精度训练的主要挑战是梯度爆炸导致的精度溢出，为了保证训练的稳定性，需针对性地重新设计模型结构，然而目前还没有任何模型开发者愿意这么做。\n* 此外，使用原生 FP8 精度训练的模型，在推理时若没有 Hopper 架构 GPU，则只能以 BF16 精度进行计算，理论上其生成效果反而不如 FP8。\n\n因此，原生 FP8 精度训练技术是极不成熟的，我们静观开源社区的技术发展。\n\n## 如何在推理时动态加载 LoRA 模型？\n\n我们支持 LoRA 模型的两种加载方式，详见[LoRA 加载](./Pipeline_Usage/Model_Inference.md#加载-lora)：\n\n* 冷加载：当基础模型未开启[显存管理](./Pipeline_Usage/VRAM_management.md)时，LoRA 会融合进基础模型权重，此时推理速度没有变化，LoRA 加载后无法卸载。\n* 热加载：当基础模型开启[显存管理](./Pipeline_Usage/VRAM_management.md)时，LoRA 不会融合进基础模型权重，此时推理速度会变慢，LoRA 加载后可通过 `pipe.clear_lora()` 卸载。\n"
  },
  {
    "path": "docs/zh/README.md",
    "content": "# DiffSynth-Studio 文档\n\n欢迎来到 Diffusion 模型的魔法世界！`DiffSynth-Studio` 是由[魔搭社区](https://www.modelscope.cn/)团队开发和维护的开源 Diffusion 模型引擎。我们期望构建一个通用的 Diffusion 模型框架，以框架建设孵化技术创新，凝聚开源社区的力量，探索生成式模型技术的边界！\n\n<details>\n\n<summary>文档阅读导引</summary>\n\n```mermaid\ngraph LR;\n    我想要使用模型进行推理和训练-->sec1[Section 1: 上手使用];\n    我想要使用模型进行推理和训练-->sec2[Section 2: 模型详解];\n    我想要使用模型进行推理和训练-->sec3[Section 3: 训练框架];\n    我想要基于此框架进行二次开发-->sec3[Section 3: 训练框架];\n    我想要基于此框架进行二次开发-->sec4[Section 4: 模型接入];\n    我想要基于此框架进行二次开发-->sec5[Section 5: API 参考];\n    我想要基于本项目探索新的技术-->sec4[Section 4: 模型接入];\n    我想要基于本项目探索新的技术-->sec5[Section 5: API 参考];\n    我想要基于本项目探索新的技术-->sec6[Section 6: 学术导引];\n    我遇到了问题-->sec7[Section 7: 常见问题];\n```\n\n</details>\n\n## Section 1: 上手使用\n\n本节介绍 `DiffSynth-Studio` 的基本使用方式，包括如何启用显存管理从而在极低显存的 GPU 上进行推理，以及如何训练任意基础模型、LoRA、ControlNet 等模型。\n\n* [安装依赖](./Pipeline_Usage/Setup.md)\n* [模型推理](./Pipeline_Usage/Model_Inference.md)\n* [显存管理](./Pipeline_Usage/VRAM_management.md)\n* [模型训练](./Pipeline_Usage/Model_Training.md)\n* [环境变量](./Pipeline_Usage/Environment_Variables.md)\n* [GPU/NPU 支持](./Pipeline_Usage/GPU_support.md)\n\n## Section 2: 模型详解\n\n本节介绍 `DiffSynth-Studio` 所支持的 Diffusion 模型，部分模型 Pipeline 具备可控生成、并行加速等特色功能。\n\n* [FLUX.1](./Model_Details/FLUX.md)\n* [Wan](./Model_Details/Wan.md)\n* [Qwen-Image](./Model_Details/Qwen-Image.md)\n* [FLUX.2](./Model_Details/FLUX2.md)\n* [Z-Image](./Model_Details/Z-Image.md)\n* [Anima](./Model_Details/Anima.md)\n* [LTX-2](./Model_Details/LTX-2.md)\n\n## Section 3: 训练框架\n\n本节介绍 `DiffSynth-Studio` 中训练框架的设计思路，帮助开发者理解 Diffusion 模型训练算法的原理。\n\n* [Diffusion 模型基本原理](./Training/Understanding_Diffusion_models.md)\n* [标准监督训练](./Training/Supervised_Fine_Tuning.md)\n* [在训练中启用 FP8 精度](./Training/FP8_Precision.md)\n* [端到端的蒸馏加速训练](./Training/Direct_Distill.md)\n* [两阶段拆分训练](./Training/Split_Training.md)\n* [差分 LoRA 训练](./Training/Differential_LoRA.md)\n\n## Section 4: 模型接入\n\n本节介绍如何将模型接入 `DiffSynth-Studio` 
从而使用框架基础功能，帮助开发者为本项目提供新模型的支持，或进行私有化模型的推理和训练。\n\n* [接入模型结构](./Developer_Guide/Integrating_Your_Model.md)\n* [接入 Pipeline](./Developer_Guide/Building_a_Pipeline.md)\n* [接入细粒度显存管理](./Developer_Guide/Enabling_VRAM_management.md)\n* [接入模型训练](./Developer_Guide/Training_Diffusion_Models.md)\n\n## Section 5: API 参考\n\n本节介绍 `DiffSynth-Studio` 中的独立核心模块 `diffsynth.core`，介绍内部的功能是如何设计和运作的，开发者如有需要，可将其中的功能模块用于其他代码库的开发中。\n\n* [`diffsynth.core.attention`](./API_Reference/core/attention.md): 注意力机制实现\n* [`diffsynth.core.data`](./API_Reference/core/data.md): 数据处理算子与通用数据集\n* [`diffsynth.core.gradient`](./API_Reference/core/gradient.md): 梯度检查点\n* [`diffsynth.core.loader`](./API_Reference/core/loader.md): 模型下载与加载\n* [`diffsynth.core.vram`](./API_Reference/core/vram.md): 显存管理\n\n## Section 6: 学术导引\n\n本节介绍如何利用 `DiffSynth-Studio` 训练新的模型，帮助科研工作者探索新的模型技术。\n\n* [从零开始训练模型](./Research_Tutorial/train_from_scratch.md)\n* [推理改进优化技术](./Research_Tutorial/inference_time_scaling.md)\n* 设计可控生成模型【coming soon】\n* 创建新的训练范式【coming soon】\n\n## Section 7: 常见问题\n\n本节总结了开发者常见的问题，如果你在使用和开发中遇到了问题，请参考本节内容，如果仍无法解决，请到 GitHub 上给我们提 issue。\n\n* [常见问题](./QA.md)\n"
  },
  {
    "path": "docs/zh/Research_Tutorial/inference_time_scaling.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"8db54992\",\n   \"metadata\": {},\n   \"source\": [\n    \"# 推理改进优化技术\\n\",\n    \"\\n\",\n    \"DiffSynth-Studio 旨在以基础框架驱动技术创新。本文以 Inference-time scaling 为例，展示如何基于 DiffSynth-Studio 构建免训练（Training-free）的图像生成增强方案。\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"0911cad4\",\n   \"metadata\": {},\n   \"source\": [\n    \"## 1. 图像质量量化\\n\",\n    \"\\n\",\n    \"首先，我们需要找到一个指标来量化图像生成模型生成的图像质量。最简单直接的方案是人工打分，但这样做的成本太高，无法大规模使用。不过，收集人工打分后，训练一个图像分类模型来预测人类的打分结果，是完全可行的。PickScore [[1]](https://arxiv.org/abs/2305.01569) 就是这样一个模型，运行下面的代码，将会自动下载并加载 [PickScore 模型](https://modelscope.cn/models/AI-ModelScope/PickScore_v1)。\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"4faca4ca\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from modelscope import AutoProcessor, AutoModel\\n\",\n    \"import torch\\n\",\n    \"\\n\",\n    \"class PickScore(torch.nn.Module):\\n\",\n    \"    def __init__(self):\\n\",\n    \"        super().__init__()\\n\",\n    \"        self.processor = AutoProcessor.from_pretrained(\\\"laion/CLIP-ViT-H-14-laion2B-s32B-b79K\\\")\\n\",\n    \"        self.model = AutoModel.from_pretrained(\\\"AI-ModelScope/PickScore_v1\\\").eval().to(\\\"cuda\\\")\\n\",\n    \"\\n\",\n    \"    def forward(self, image, prompt):\\n\",\n    \"        image_inputs = self.processor(images=image, padding=True, truncation=True, max_length=77, return_tensors=\\\"pt\\\").to(\\\"cuda\\\")\\n\",\n    \"        text_inputs = self.processor(text=prompt, padding=True, truncation=True, max_length=77, return_tensors=\\\"pt\\\").to(\\\"cuda\\\")\\n\",\n    \"        with torch.inference_mode():\\n\",\n    \"            image_embs = self.model.get_image_features(**image_inputs).pooler_output\\n\",\n    \"            image_embs = image_embs / torch.norm(image_embs, dim=-1, keepdim=True)\\n\",\n    \"            text_embs = 
self.model.get_text_features(**text_inputs).pooler_output\\n\",\n    \"            text_embs = text_embs / torch.norm(text_embs, dim=-1, keepdim=True)\\n\",\n    \"            score = (text_embs @ image_embs.T).flatten().item()\\n\",\n    \"        return score\\n\",\n    \"\\n\",\n    \"reward_model = PickScore()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"5f807cec\",\n   \"metadata\": {},\n   \"source\": [\n    \"## 2. Inference-time Scaling 技术\\n\",\n    \"\\n\",\n    \"Inference-time Scaling [[2]](https://arxiv.org/abs/2504.00294) 是一类有趣的技术，旨在通过增加推理时的计算量来提升生成结果的质量。例如，在语言模型中，[Qwen/Qwen3.5-27B](https://modelscope.cn/models/Qwen/Qwen3.5-27B)、[deepseek-ai/DeepSeek-R1](https://modelscope.cn/models/deepseek-ai/DeepSeek-R1) 等模型通过“思考模式”引导模型花更多时间仔细思考，让回答结果更准确。接下来我们以模型 [black-forest-labs/FLUX.2-klein-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B) 为例，探讨如何为图像生成模型设计 Inference-time Scaling 方案。\\n\",\n    \"\\n\",\n    \"> 在开始前，我们稍微改造了 `Flux2ImagePipeline` 的代码，使其能够根据输入的特定高斯噪声矩阵进行初始化，便于复现结果，详见 [diffsynth/pipelines/flux2_image.py](https://github.com/modelscope/DiffSynth-Studio/blob/main/diffsynth/pipelines/flux2_image.py) 中的 `Flux2Unit_NoiseInitializer`。\\n\",\n    \"\\n\",\n    \"运行以下代码，加载模型 [black-forest-labs/FLUX.2-klein-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B)。\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"c5818a87\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig\\n\",\n    \"\\n\",\n    \"pipe = Flux2ImagePipeline.from_pretrained(\\n\",\n    \"    torch_dtype=torch.bfloat16,\\n\",\n    \"    device=\\\"cuda\\\",\\n\",\n    \"    model_configs=[\\n\",\n    \"        ModelConfig(model_id=\\\"black-forest-labs/FLUX.2-klein-4B\\\", origin_file_pattern=\\\"text_encoder/*.safetensors\\\"),\\n\",\n    \"        ModelConfig(model_id=\\\"black-forest-labs/FLUX.2-klein-4B\\\", 
origin_file_pattern=\\\"transformer/*.safetensors\\\"),\\n\",\n    \"        ModelConfig(model_id=\\\"black-forest-labs/FLUX.2-klein-4B\\\", origin_file_pattern=\\\"vae/diffusion_pytorch_model.safetensors\\\"),\\n\",\n    \"    ],\\n\",\n    \"    tokenizer_config=ModelConfig(model_id=\\\"black-forest-labs/FLUX.2-klein-4B\\\", origin_file_pattern=\\\"tokenizer/\\\"),\\n\",\n    \")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"f58e9945\",\n   \"metadata\": {},\n   \"source\": [\n    \"用提示词 `\\\"sketch, a cat\\\"` 生成一只素描猫猫，并用 PickScore 模型打分。\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"6ea2d258\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def evaluate_noise(noise, pipe, reward_model, prompt):\\n\",\n    \"    # Generate an image and compute the score.\\n\",\n    \"    image = pipe(\\n\",\n    \"        prompt=prompt,\\n\",\n    \"        num_inference_steps=4,\\n\",\n    \"        initial_noise=noise,\\n\",\n    \"        progress_bar_cmd=lambda x: x,\\n\",\n    \"    )\\n\",\n    \"    score = reward_model(image, prompt)\\n\",\n    \"    return score\\n\",\n    \"\\n\",\n    \"torch.manual_seed(1)\\n\",\n    \"prompt = \\\"sketch, a cat\\\"\\n\",\n    \"noise = pipe.generate_noise((1, 128, 64, 64), rand_device=\\\"cuda\\\", rand_torch_dtype=pipe.torch_dtype)\\n\",\n    \"\\n\",\n    \"image_1 = pipe(prompt, num_inference_steps=4, initial_noise=noise)\\n\",\n    \"print(\\\"Score:\\\", reward_model(image_1, prompt))\\n\",\n    \"image_1\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"5e11694e\",\n   \"metadata\": {},\n   \"source\": [\n    \"### 2.1 Best-of-N 随机搜索\\n\",\n    \"\\n\",\n    \"模型的生成结果具有一定的随机性，如果用不同的随机种子，生成的图像结果也是不同的，有时图像质量高，有时图像质量低。那么，我们有一个简单的 Inference-time scaling 方案：使用多个不同的随机种子分别生成图像，然后利用 PickScore 进行打分，只保留分数最高的那一张。\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"241f10d2\",\n   \"metadata\": 
{},\n   \"outputs\": [],\n   \"source\": [\n    \"from tqdm import tqdm\\n\",\n    \"\\n\",\n    \"def random_search(base_latents, objective_reward_fn, total_eval_budget):\\n\",\n    \"    # Search for the noise randomly.\\n\",\n    \"    best_noise = base_latents\\n\",\n    \"    best_score = objective_reward_fn(base_latents)\\n\",\n    \"    for it in tqdm(range(total_eval_budget - 1)):\\n\",\n    \"        noise = pipe.generate_noise((1, 128, 64, 64), seed=None)\\n\",\n    \"        score = objective_reward_fn(noise)\\n\",\n    \"        if score > best_score:\\n\",\n    \"            best_score, best_noise = score, noise\\n\",\n    \"    return best_noise\\n\",\n    \"\\n\",\n    \"best_noise = random_search(\\n\",\n    \"    base_latents=noise,\\n\",\n    \"    objective_reward_fn=lambda noise: evaluate_noise(noise, pipe, reward_model, prompt),\\n\",\n    \"    total_eval_budget=50,\\n\",\n    \")\\n\",\n    \"image_2 = pipe(prompt, num_inference_steps=4, initial_noise=best_noise)\\n\",\n    \"print(\\\"Score:\\\", reward_model(image_2, prompt))\\n\",\n    \"image_2\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"8e9bf966\",\n   \"metadata\": {},\n   \"source\": [\n    \"我们可以清晰地看到，经过多次随机搜索后，最终选出的猫猫毛发细节更加丰富，PickScore 分数也有明显提升。但这种暴力的随机搜索效率极低，生成时间成倍增长，且很容易触及质量上限。因此，我们希望能够找到一种更高效的搜索方法，在同等计算预算下达到更高的分数。\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"c9578349\",\n   \"metadata\": {},\n   \"source\": [\n    \"### 2.2 SES 搜索\\n\",\n    \"\\n\",\n    \"为了突破随机搜索的瓶颈，我们引入了 SES (Spectral Evolution Search) 算法 [[3]](https://arxiv.org/abs/2602.03208)，详细的代码位于 [diffsynth/utils/ses](https://github.com/modelscope/DiffSynth-Studio/blob/main/diffsynth/utils/ses)。\\n\",\n    \"\\n\",\n    \"扩散模型生成的图像，很大程度上由初始噪声的低频分量决定。SES 算法通过小波变换将高斯噪声分解，固定高频细节，专门针对低频部分使用交叉熵方法进行演化搜索，能以更高的效率找到优质的初始噪声。\\n\",\n    \"\\n\",\n    \"运行下面的代码，即可使用 SES 更高效地搜索最佳的高斯噪声矩阵。\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": 
\"adeed2aa\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from diffsynth.utils.ses import ses_search\\n\",\n    \"\\n\",\n    \"best_noise = ses_search(\\n\",\n    \"    base_latents=noise,\\n\",\n    \"    objective_reward_fn=lambda noise: evaluate_noise(noise, pipe, reward_model, prompt),\\n\",\n    \"    total_eval_budget=50,\\n\",\n    \")\\n\",\n    \"image_3 = pipe(prompt, num_inference_steps=4, initial_noise=best_noise)\\n\",\n    \"print(\\\"Score:\\\", reward_model(image_3, prompt))\\n\",\n    \"image_3\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"940a97f1\",\n   \"metadata\": {},\n   \"source\": [\n    \"可以观察到，在同样的计算预算下，相比于随机搜索，SES 的结果在 PickScore 得分上取得了显著的提升。“素描猫猫”展现出了更精致的整体构图以及更具层次感的明暗对比。\\n\",\n    \"\\n\",\n    \"Inference-time scaling 能够以更长推理时间为代价获得更高的图像质量，那么它生成的图像数据也可以用 DPO [[4]](https://arxiv.org/abs/2311.12908)、差分训练 [[5]](https://arxiv.org/abs/2412.12888) 等方式赋予模型自身，那就是另外一个有趣的探索方向了。\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"dzj8\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.10.19\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 5\n}\n"
  },
  {
    "path": "docs/zh/Research_Tutorial/inference_time_scaling.md",
    "content": "# 推理改进优化技术\n\nDiffSynth-Studio 旨在以基础框架驱动技术创新。本文以 Inference-time scaling 为例，展示如何基于 DiffSynth-Studio 构建免训练（Training-free）的图像生成增强方案。\n\nNotebook: https://github.com/modelscope/DiffSynth-Studio/blob/main/docs/zh/Research_Tutorial/inference_time_scaling.ipynb\n\n## 1. 图像质量量化\n\n首先，我们需要找到一个指标来量化图像生成模型生成的图像质量。最简单直接的方案是人工打分，但这样做的成本太高，无法大规模使用。不过，收集人工打分后，训练一个图像分类模型来预测人类的打分结果，是完全可行的。PickScore [[1]](https://arxiv.org/abs/2305.01569) 就是这样一个模型，运行下面的代码，将会自动下载并加载 [PickScore 模型](https://modelscope.cn/models/AI-ModelScope/PickScore_v1)。\n\n```python\nfrom modelscope import AutoProcessor, AutoModel\nimport torch\n\nclass PickScore(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.processor = AutoProcessor.from_pretrained(\"laion/CLIP-ViT-H-14-laion2B-s32B-b79K\")\n        self.model = AutoModel.from_pretrained(\"AI-ModelScope/PickScore_v1\").eval().to(\"cuda\")\n\n    def forward(self, image, prompt):\n        image_inputs = self.processor(images=image, padding=True, truncation=True, max_length=77, return_tensors=\"pt\").to(\"cuda\")\n        text_inputs = self.processor(text=prompt, padding=True, truncation=True, max_length=77, return_tensors=\"pt\").to(\"cuda\")\n        with torch.inference_mode():\n            image_embs = self.model.get_image_features(**image_inputs).pooler_output\n            image_embs = image_embs / torch.norm(image_embs, dim=-1, keepdim=True)\n            text_embs = self.model.get_text_features(**text_inputs).pooler_output\n            text_embs = text_embs / torch.norm(text_embs, dim=-1, keepdim=True)\n            score = (text_embs @ image_embs.T).flatten().item()\n        return score\n\nreward_model = PickScore()\n```\n\n## 2. 
Inference-time Scaling 技术\n\nInference-time Scaling [[2]](https://arxiv.org/abs/2504.00294) 是一类有趣的技术，旨在通过增加推理时的计算量来提升生成结果的质量。例如，在语言模型中，[Qwen/Qwen3.5-27B](https://modelscope.cn/models/Qwen/Qwen3.5-27B)、[deepseek-ai/DeepSeek-R1](https://modelscope.cn/models/deepseek-ai/DeepSeek-R1) 等模型通过“思考模式”引导模型花更多时间仔细思考，让回答结果更准确。接下来我们以模型 [black-forest-labs/FLUX.2-klein-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B) 为例，探讨如何为图像生成模型设计 Inference-time Scaling 方案。\n\n> 在开始前，我们稍微改造了 `Flux2ImagePipeline` 的代码，使其能够根据输入的特定高斯噪声矩阵进行初始化，便于复现结果，详见 [diffsynth/pipelines/flux2_image.py](https://github.com/modelscope/DiffSynth-Studio/blob/main/diffsynth/pipelines/flux2_image.py) 中的 `Flux2Unit_NoiseInitializer`。\n\n运行以下代码，加载模型 [black-forest-labs/FLUX.2-klein-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B)。\n\n```python\nfrom diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig\n\npipe = Flux2ImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-4B\", origin_file_pattern=\"text_encoder/*.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-4B\", origin_file_pattern=\"transformer/*.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-4B\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-4B\", origin_file_pattern=\"tokenizer/\"),\n)\n```\n\n用提示词 `\"sketch, a cat\"` 生成一只素描猫猫，并用 PickScore 模型打分。\n\n```python\ndef evaluate_noise(noise, pipe, reward_model, prompt):\n    # Generate an image and compute the score.\n    image = pipe(\n        prompt=prompt,\n        num_inference_steps=4,\n        initial_noise=noise,\n        progress_bar_cmd=lambda x: x,\n    )\n    score = reward_model(image, prompt)\n    return score\n\ntorch.manual_seed(1)\nprompt = \"sketch, a cat\"\nnoise = pipe.generate_noise((1, 128, 64, 
64), rand_device=\"cuda\", rand_torch_dtype=pipe.torch_dtype)\n\nimage_1 = pipe(prompt, num_inference_steps=4, initial_noise=noise)\nprint(\"Score:\", reward_model(image_1, prompt))\nimage_1\n```\n\n![Image](https://github.com/user-attachments/assets/b6546c6d-b368-4463-b703-d561a9134ba0)\n\n### 2.1 Best-of-N 随机搜索\n\n模型的生成结果具有一定的随机性，如果用不同的随机种子，生成的图像结果也是不同的，有时图像质量高，有时图像质量低。那么，我们有一个简单的 Inference-time scaling 方案：使用多个不同的随机种子分别生成图像，然后利用 PickScore 进行打分，只保留分数最高的那一张。\n\n```python\nfrom tqdm import tqdm\n\ndef random_search(base_latents, objective_reward_fn, total_eval_budget):\n    # Search for the noise randomly.\n    best_noise = base_latents\n    best_score = objective_reward_fn(base_latents)\n    for it in tqdm(range(total_eval_budget - 1)):\n        noise = pipe.generate_noise((1, 128, 64, 64), seed=None)\n        score = objective_reward_fn(noise)\n        if score > best_score:\n            best_score, best_noise = score, noise\n    return best_noise\n\nbest_noise = random_search(\n    base_latents=noise,\n    objective_reward_fn=lambda noise: evaluate_noise(noise, pipe, reward_model, prompt),\n    total_eval_budget=50,\n)\nimage_2 = pipe(prompt, num_inference_steps=4, initial_noise=best_noise)\nprint(\"Score:\", reward_model(image_2, prompt))\nimage_2\n```\n\n![Image](https://github.com/user-attachments/assets/b8dba70a-daa8-4368-8f32-a6c150daecb5)\n\n我们可以清晰地看到，经过多次随机搜索后，最终选出的猫猫毛发细节更加丰富，PickScore 分数也有明显提升。但这种暴力的随机搜索效率极低，生成时间成倍增长，且很容易触及质量上限。因此，我们希望能够找到一种更高效的搜索方法，在同等计算预算下达到更高的分数。\n\n### 2.2 SES 搜索\n\n为了突破随机搜索的瓶颈，我们引入了 SES (Spectral Evolution Search) 算法 [[3]](https://arxiv.org/abs/2602.03208)，详细的代码位于 [diffsynth/utils/ses](https://github.com/modelscope/DiffSynth-Studio/blob/main/diffsynth/utils/ses)。\n\n扩散模型生成的图像，很大程度上由初始噪声的低频分量决定。SES 算法通过小波变换将高斯噪声分解，固定高频细节，专门针对低频部分使用交叉熵方法进行演化搜索，能以更高的效率找到优质的初始噪声。\n\n运行下面的代码，即可使用 SES 更高效地搜索最佳的高斯噪声矩阵。\n\n```python\nfrom diffsynth.utils.ses import ses_search\n\nbest_noise = ses_search(\n    base_latents=noise,\n    
objective_reward_fn=lambda noise: evaluate_noise(noise, pipe, reward_model, prompt),\n    total_eval_budget=50,\n)\nimage_3 = pipe(prompt, num_inference_steps=4, initial_noise=best_noise)\nprint(\"Score:\", reward_model(image_3, prompt))\nimage_3\n```\n\n![Image](https://github.com/user-attachments/assets/9a3f7598-3812-46d2-b333-cd65e49886ab)\n\n可以观察到，在同样的计算预算下，相比于随机搜索，SES 的结果在 PickScore 得分上取得了显著的提升。“素描猫猫”展现出了更精致的整体构图以及更具层次感的明暗对比。\n\nInference-time scaling 能够以更长推理时间为代价获得更高的图像质量，那么它生成的图像数据也可以用 DPO [[4]](https://arxiv.org/abs/2311.12908)、差分训练 [[5]](https://arxiv.org/abs/2412.12888) 等方式赋予模型自身，那就是另外一个有趣的探索方向了。\n"
  },
  {
    "path": "docs/zh/Research_Tutorial/train_from_scratch.md",
    "content": "# 从零开始训练模型\n\nDiffSynth-Studio 的训练引擎支持从零开始训练基础模型，本文介绍如何从零开始训练一个参数量仅为 0.1B 的小型文生图模型。\n\n## 1. 构建模型结构\n\n### 1.1 Diffusion 模型\n\n从 UNet [[1]](https://arxiv.org/abs/1505.04597) [[2]](https://arxiv.org/abs/2112.10752) 到 DiT [[3]](https://arxiv.org/abs/2212.09748) [[4]](https://arxiv.org/abs/2403.03206)，Diffusion 的主流模型结构经历了多次演变。通常，一个 Diffusion 模型的输入包括：\n\n* 图像张量（`latents`）：图像的编码，由 VAE 模型产生，含有部分噪声\n* 文本张量（`prompt_embeds`）：文本的编码，由文本编码器产生\n* 时间步（`timestep`）：标量，用于标记当前处于 Diffusion 过程的哪个阶段\n\n模型的输出是与图像张量形状相同的张量，表示模型预测的去噪方向，关于 Diffusion 模型理论的细节，请参考 [Diffusion 模型基本原理](../Training/Understanding_Diffusion_models.md)。在本文中，我们构建一个仅含 0.1B 参数的 DiT 模型：`AAADiT`。\n\n<details>\n<summary>模型结构代码</summary>\n\n```python\nimport torch, accelerate\nfrom PIL import Image\nfrom typing import Union\nfrom tqdm import tqdm\nfrom einops import rearrange, repeat\n\nfrom transformers import AutoProcessor, AutoTokenizer\nfrom diffsynth.core import ModelConfig, gradient_checkpoint_forward, attention_forward, UnifiedDataset, load_model\nfrom diffsynth.diffusion import FlowMatchScheduler, DiffusionTrainingModule, FlowMatchSFTLoss, ModelLogger, launch_training_task\nfrom diffsynth.diffusion.base_pipeline import BasePipeline, PipelineUnit\nfrom diffsynth.models.general_modules import TimestepEmbeddings\nfrom diffsynth.models.z_image_text_encoder import ZImageTextEncoder\nfrom diffsynth.models.flux2_vae import Flux2VAE\n\n\nclass AAAPositionalEmbedding(torch.nn.Module):\n    def __init__(self, height=16, width=16, dim=1024):\n        super().__init__()\n        self.image_emb = torch.nn.Parameter(torch.randn((1, dim, height, width)))\n        self.text_emb = torch.nn.Parameter(torch.randn((dim,)))\n\n    def forward(self, image, text):\n        height, width = image.shape[-2:]\n        image_emb = self.image_emb.to(device=image.device, dtype=image.dtype)\n        image_emb = torch.nn.functional.interpolate(image_emb, size=(height, width), mode=\"bilinear\")\n        image_emb = 
rearrange(image_emb, \"B C H W -> B (H W) C\")\n        text_emb = self.text_emb.to(device=text.device, dtype=text.dtype)\n        text_emb = repeat(text_emb, \"C -> B L C\", B=text.shape[0], L=text.shape[1])\n        emb = torch.concat([image_emb, text_emb], dim=1)\n        return emb\n\n\nclass AAABlock(torch.nn.Module):\n    def __init__(self, dim=1024, num_heads=32):\n        super().__init__()\n        self.norm_attn = torch.nn.RMSNorm(dim, elementwise_affine=False)\n        self.to_q = torch.nn.Linear(dim, dim)\n        self.to_k = torch.nn.Linear(dim, dim)\n        self.to_v = torch.nn.Linear(dim, dim)\n        self.to_out = torch.nn.Linear(dim, dim)\n        self.norm_mlp = torch.nn.RMSNorm(dim, elementwise_affine=False)\n        self.ff = torch.nn.Sequential(\n            torch.nn.Linear(dim, dim*3),\n            torch.nn.SiLU(),\n            torch.nn.Linear(dim*3, dim),\n        )\n        self.to_gate = torch.nn.Linear(dim, dim * 2)\n        self.num_heads = num_heads\n\n    def attention(self, emb, pos_emb):\n        emb = self.norm_attn(emb + pos_emb)\n        q, k, v = self.to_q(emb), self.to_k(emb), self.to_v(emb)\n        emb = attention_forward(\n            q, k, v,\n            q_pattern=\"b s (n d)\", k_pattern=\"b s (n d)\", v_pattern=\"b s (n d)\", out_pattern=\"b s (n d)\",\n            dims={\"n\": self.num_heads},\n        )\n        emb = self.to_out(emb)\n        return emb\n    \n    def feed_forward(self, emb, pos_emb):\n        emb = self.norm_mlp(emb + pos_emb)\n        emb = self.ff(emb)\n        return emb\n    \n    def forward(self, emb, pos_emb, t_emb):\n        gate_attn, gate_mlp = self.to_gate(t_emb).chunk(2, dim=-1)\n        emb = emb + self.attention(emb, pos_emb) * (1 + gate_attn)\n        emb = emb + self.feed_forward(emb, pos_emb) * (1 + gate_mlp)\n        return emb\n\n\nclass AAADiT(torch.nn.Module):\n    def __init__(self, dim=1024):\n        super().__init__()\n        self.pos_embedder = 
AAAPositionalEmbedding(dim=dim)\n        self.timestep_embedder = TimestepEmbeddings(256, dim)\n        self.image_embedder = torch.nn.Sequential(torch.nn.Linear(128, dim), torch.nn.LayerNorm(dim))\n        self.text_embedder = torch.nn.Sequential(torch.nn.Linear(1024, dim), torch.nn.LayerNorm(dim))\n        self.blocks = torch.nn.ModuleList([AAABlock(dim) for _ in range(10)])\n        self.proj_out = torch.nn.Linear(dim, 128)\n\n    def forward(\n        self,\n        latents,\n        prompt_embeds,\n        timestep,\n        use_gradient_checkpointing=False,\n        use_gradient_checkpointing_offload=False,\n    ):\n        pos_emb = self.pos_embedder(latents, prompt_embeds)\n        t_emb = self.timestep_embedder(timestep, dtype=latents.dtype).view(1, 1, -1)\n        image = self.image_embedder(rearrange(latents, \"B C H W -> B (H W) C\"))\n        text = self.text_embedder(prompt_embeds)\n        emb = torch.concat([image, text], dim=1)\n        for block_id, block in enumerate(self.blocks):\n            emb = gradient_checkpoint_forward(\n                block,\n                use_gradient_checkpointing=use_gradient_checkpointing,\n                use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,\n                emb=emb,\n                pos_emb=pos_emb,\n                t_emb=t_emb,\n            )\n        emb = emb[:, :latents.shape[-1] * latents.shape[-2]]\n        emb = self.proj_out(emb)\n        emb = rearrange(emb, \"B (H W) C -> B C H W\", W=latents.shape[-1])\n        return emb\n```\n\n</details>\n\n### 1.2 编解码器模型\n\n除了用于去噪的 Diffusion 模型以外，我们还需要另外两个模型：\n\n* 文本编码器：用于将文本编码为张量。我们采用 [Qwen/Qwen3-0.6B](https://modelscope.cn/models/Qwen/Qwen3-0.6B) 模型。\n* VAE 编解码器：编码器部分用于将图像编码为张量，解码器部分用于将图像张量解码为图像。我们采用 [black-forest-labs/FLUX.2-klein-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B) 中的 VAE 模型。\n\n这两个模型的结构都已集成在 DiffSynth-Studio 中，分别位于 
[/diffsynth/models/z_image_text_encoder.py](https://github.com/modelscope/DiffSynth-Studio/blob/main/diffsynth/models/z_image_text_encoder.py) 和 [/diffsynth/models/flux2_vae.py](https://github.com/modelscope/DiffSynth-Studio/blob/main/diffsynth/models/flux2_vae.py)，因此我们不需要修改任何代码。\n\n## 2. 构建 Pipeline\n\n我们在文档 [接入 Pipeline](../Developer_Guide/Building_a_Pipeline.md) 中介绍了如何构建一个模型 Pipeline，对于本文中的模型，我们也需要构建一个 Pipeline，连接文本编码器、Diffusion 模型、VAE 编解码器。\n\n<details>\n<summary>Pipeline 代码</summary>\n\n```python\nclass AAAImagePipeline(BasePipeline):\n    def __init__(self, device=\"cuda\", torch_dtype=torch.bfloat16):\n        super().__init__(\n            device=device, torch_dtype=torch_dtype,\n            height_division_factor=16, width_division_factor=16,\n        )\n        self.scheduler = FlowMatchScheduler(\"FLUX.2\")\n        self.text_encoder: ZImageTextEncoder = None\n        self.dit: AAADiT = None\n        self.vae: Flux2VAE = None\n        self.tokenizer: AutoProcessor = None\n        self.in_iteration_models = (\"dit\",)\n        self.units = [\n            AAAUnit_PromptEmbedder(),\n            AAAUnit_NoiseInitializer(),\n            AAAUnit_InputImageEmbedder(),\n        ]\n        self.model_fn = model_fn_aaa\n    \n    @staticmethod\n    def from_pretrained(\n        torch_dtype: torch.dtype = torch.bfloat16,\n        device: Union[str, torch.device] = \"cuda\",\n        model_configs: list[ModelConfig] = [],\n        tokenizer_config: ModelConfig = None,\n        vram_limit: float = None,\n    ):\n        # Initialize pipeline\n        pipe = AAAImagePipeline(device=device, torch_dtype=torch_dtype)\n        model_pool = pipe.download_and_load_models(model_configs, vram_limit)\n        \n        # Fetch models\n        pipe.text_encoder = model_pool.fetch_model(\"z_image_text_encoder\")\n        pipe.dit = model_pool.fetch_model(\"aaa_dit\")\n        pipe.vae = model_pool.fetch_model(\"flux2_vae\")\n        if tokenizer_config is not None:\n            
tokenizer_config.download_if_necessary()\n            pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)\n        \n        # VRAM Management\n        pipe.vram_management_enabled = pipe.check_vram_management_state()\n        return pipe\n    \n    @torch.no_grad()\n    def __call__(\n        self,\n        # Prompt\n        prompt: str,\n        negative_prompt: str = \"\",\n        cfg_scale: float = 1.0,\n        # Image\n        input_image: Image.Image = None,\n        denoising_strength: float = 1.0,\n        # Shape\n        height: int = 1024,\n        width: int = 1024,\n        # Randomness\n        seed: int = None,\n        rand_device: str = \"cpu\",\n        # Steps\n        num_inference_steps: int = 30,\n        # Progress bar\n        progress_bar_cmd = tqdm,\n    ):\n        self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=height//16*width//16)\n\n        # Parameters\n        inputs_posi = {\"prompt\": prompt}\n        inputs_nega = {\"negative_prompt\": negative_prompt}\n        inputs_shared = {\n            \"cfg_scale\": cfg_scale,\n            \"input_image\": input_image, \"denoising_strength\": denoising_strength,\n            \"height\": height, \"width\": width,\n            \"seed\": seed, \"rand_device\": rand_device,\n            \"num_inference_steps\": num_inference_steps,\n        }\n        for unit in self.units:\n            inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)\n\n        # Denoise\n        self.load_models_to_device(self.in_iteration_models)\n        models = {name: getattr(self, name) for name in self.in_iteration_models}\n        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):\n            timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)\n            noise_pred = self.cfg_guided_model_fn(\n                
self.model_fn, cfg_scale,\n                inputs_shared, inputs_posi, inputs_nega,\n                **models, timestep=timestep, progress_id=progress_id\n            )\n            inputs_shared[\"latents\"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)\n        \n        # Decode\n        self.load_models_to_device(['vae'])\n        image = self.vae.decode(inputs_shared[\"latents\"])\n        image = self.vae_output_to_image(image)\n        self.load_models_to_device([])\n\n        return image\n\n\nclass AAAUnit_PromptEmbedder(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            seperate_cfg=True,\n            input_params_posi={\"prompt\": \"prompt\"},\n            input_params_nega={\"prompt\": \"negative_prompt\"},\n            output_params=(\"prompt_embeds\",),\n            onload_model_names=(\"text_encoder\",)\n        )\n        self.hidden_states_layers = (-1,)\n\n    def process(self, pipe: AAAImagePipeline, prompt):\n        pipe.load_models_to_device(self.onload_model_names)\n        text = pipe.tokenizer.apply_chat_template(\n            [{\"role\": \"user\", \"content\": prompt}],\n            tokenize=False,\n            add_generation_prompt=True,\n            enable_thinking=False,\n        )\n        inputs = pipe.tokenizer(text, return_tensors=\"pt\", padding=\"max_length\", truncation=True, max_length=128).to(pipe.device)\n        output = pipe.text_encoder(**inputs, output_hidden_states=True, use_cache=False)\n        prompt_embeds = torch.concat([output.hidden_states[k] for k in self.hidden_states_layers], dim=-1)\n        return {\"prompt_embeds\": prompt_embeds}\n\n\nclass AAAUnit_NoiseInitializer(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            input_params=(\"height\", \"width\", \"seed\", \"rand_device\"),\n            output_params=(\"noise\",),\n        )\n\n    def process(self, pipe: AAAImagePipeline, height, width, seed, 
rand_device):\n        noise = pipe.generate_noise((1, 128, height//16, width//16), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)\n        return {\"noise\": noise}\n\n\nclass AAAUnit_InputImageEmbedder(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            input_params=(\"input_image\", \"noise\"),\n            output_params=(\"latents\", \"input_latents\"),\n            onload_model_names=(\"vae\",)\n        )\n\n    def process(self, pipe: AAAImagePipeline, input_image, noise):\n        if input_image is None:\n            return {\"latents\": noise, \"input_latents\": None}\n        pipe.load_models_to_device(['vae'])\n        image = pipe.preprocess_image(input_image)\n        input_latents = pipe.vae.encode(image)\n        if pipe.scheduler.training:\n            return {\"latents\": noise, \"input_latents\": input_latents}\n        else:\n            latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])\n            return {\"latents\": latents, \"input_latents\": input_latents}\n\n\ndef model_fn_aaa(\n    dit: AAADiT,\n    latents=None,\n    prompt_embeds=None,\n    timestep=None,\n    use_gradient_checkpointing=False,\n    use_gradient_checkpointing_offload=False,\n    **kwargs,\n):\n    model_output = dit(\n        latents,\n        prompt_embeds,\n        timestep,\n        use_gradient_checkpointing=use_gradient_checkpointing,\n        use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,\n    )\n    return model_output\n```\n\n</details>\n\n## 3. 
准备数据集\n\n为了快速验证训练效果，我们使用数据集 [宝可梦-第一世代](https://modelscope.cn/datasets/DiffSynth-Studio/pokemon-gen1)，这个数据集转载自开源项目 [pokemon-dataset-zh](https://github.com/42arch/pokemon-dataset-zh)，包含从妙蛙种子到梦幻的 151 个第一世代宝可梦。如果你想使用其他数据集，请参考文档 [准备数据集](../Pipeline_Usage/Model_Training.md#准备数据集) 和 [`diffsynth.core.data`](../API_Reference/core/data.md)。\n\n```shell\nmodelscope download --dataset DiffSynth-Studio/pokemon-gen1 --local_dir ./data\n```\n\n## 4. 开始训练\n\n训练过程可使用 Pipeline 快速实现，我们已将完整的代码放在 [../Research_Tutorial/train_from_scratch.py](https://github.com/modelscope/DiffSynth-Studio/blob/main/docs/zh/Research_Tutorial/train_from_scratch.py)，可直接通过 `python docs/zh/Research_Tutorial/train_from_scratch.py` 开始单 GPU 训练。\n\n如需开启多 GPU 并行训练，请运行 `accelerate config` 设置相关参数，然后使用命令 `accelerate launch docs/zh/Research_Tutorial/train_from_scratch.py` 开始训练。\n\n这个训练脚本没有设置停止条件，请在需要时手动关闭。模型在训练大约 6 万步后收敛，单 GPU 训练需要 10～20 小时。\n\n\n<details>\n<summary>训练代码</summary>\n\n```python\nclass AAATrainingModule(DiffusionTrainingModule):\n    def __init__(self, device):\n        super().__init__()\n        self.pipe = AAAImagePipeline.from_pretrained(\n            torch_dtype=torch.bfloat16,\n            device=device,\n            model_configs=[\n                ModelConfig(model_id=\"Qwen/Qwen3-0.6B\", origin_file_pattern=\"model.safetensors\"),\n                ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-4B\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n            ],\n            tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen3-0.6B\", origin_file_pattern=\"./\"),\n        )\n        self.pipe.dit = AAADiT().to(dtype=torch.bfloat16, device=device)\n        self.pipe.freeze_except([\"dit\"])\n        self.pipe.scheduler.set_timesteps(1000, training=True)\n\n    def forward(self, data):\n        inputs_posi = {\"prompt\": data[\"prompt\"]}\n        inputs_nega = {\"negative_prompt\": \"\"}\n        inputs_shared = {\n            \"input_image\": data[\"image\"],\n           
 \"height\": data[\"image\"].size[1],\n            \"width\": data[\"image\"].size[0],\n            \"cfg_scale\": 1,\n            \"use_gradient_checkpointing\": False,\n            \"use_gradient_checkpointing_offload\": False,\n        }\n        for unit in self.pipe.units:\n            inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega)\n        loss = FlowMatchSFTLoss(self.pipe, **inputs_shared, **inputs_posi)\n        return loss\n\n\nif __name__ == \"__main__\":\n    accelerator = accelerate.Accelerator(gradient_accumulation_steps=1)\n    dataset = UnifiedDataset(\n        base_path=\"data/images\",\n        metadata_path=\"data/metadata_merged.csv\",\n        max_data_items=10000000,\n        data_file_keys=(\"image\",),\n        main_data_operator=UnifiedDataset.default_image_operator(base_path=\"data/images\", height=256, width=256)\n    )\n    model = AAATrainingModule(device=accelerator.device)\n    model_logger = ModelLogger(\n        \"models/AAA/v1\",\n        remove_prefix_in_ckpt=\"pipe.dit.\",\n    )\n    launch_training_task(\n        accelerator, dataset, model, model_logger,\n        learning_rate=2e-4,\n        num_workers=4,\n        save_steps=50000,\n        num_epochs=999999,\n    )\n```\n\n</details>\n\n## 5. 
验证训练效果\n\n如果你不想等待模型训练完成，可以直接下载[我们预先训练好的模型](https://modelscope.cn/models/DiffSynth-Studio/AAAMyModel)。\n\n```shell\nmodelscope download --model DiffSynth-Studio/AAAMyModel step-600000.safetensors --local_dir models/DiffSynth-Studio/AAAMyModel\n```\n\n加载模型\n\n```python\nfrom diffsynth import load_model\n\npipe = AAAImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen3-0.6B\", origin_file_pattern=\"model.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-4B\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen3-0.6B\", origin_file_pattern=\"./\"),\n)\npipe.dit = load_model(AAADiT, \"models/DiffSynth-Studio/AAAMyModel/step-600000.safetensors\", torch_dtype=torch.bfloat16, device=\"cuda\")\n```\n\n模型推理，生成第一世代宝可梦“御三家”，此时模型生成的图像内容与训练数据基本一致。\n\n```python\nfor seed, prompt in enumerate([\n    \"green, lizard, plant, Grass, Poison, seed on back, red eyes, smiling expression, short stout limbs, sharp claws\",\n    \"orange, cream, lizard, Fire, flame on tail tip, large eyes, smiling expression, cream-colored belly patch, sharp claws\",\n    \"蓝色，米色，棕色，乌龟，水系，龟壳，大眼睛，短四肢，卷曲尾巴\",\n]):\n    image = pipe(\n        prompt=prompt,\n        negative_prompt=\" \",\n        num_inference_steps=30,\n        cfg_scale=10,\n        seed=seed,\n        height=256, width=256,\n    )\n    image.save(f\"image_{seed}.jpg\")\n```\n\n|![Image](https://github.com/user-attachments/assets/3c620fbf-5d28-4a1a-b887-519d85ac7d1c)|![Image](https://github.com/user-attachments/assets/909efd4c-9e61-4b33-9321-39da0e499b00)|![Image](https://github.com/user-attachments/assets/f3474bcd-b474-4a90-a1ea-579f67e161e3)|\n|-|-|-|\n\n模型推理，生成具有“锐利爪子”的宝可梦，此时不同的随机种子能够产生不同的图像结果。\n\n```python\nfor seed, prompt in enumerate([\n    \"sharp claws\",\n    \"sharp claws\",\n    \"sharp claws\",\n]):\n    image = pipe(\n        
prompt=prompt,\n        negative_prompt=\" \",\n        num_inference_steps=30,\n        cfg_scale=10,\n        seed=seed+4,\n        height=256, width=256,\n    )\n    image.save(f\"image_sharp_claws_{seed}.jpg\")\n```\n\n|![Image](https://github.com/user-attachments/assets/94862edd-96ae-4276-a38f-795249f11a13)|![Image](https://github.com/user-attachments/assets/b2291f23-20ba-42de-8bfd-76cb4afc6eea)|![Image](https://github.com/user-attachments/assets/f2aab9a4-85ec-498e-8039-648b1289796e)|\n|-|-|-|\n\n现在，我们获得了一个 0.1B 的小型文生图模型，这个模型已经能够生成 151 个宝可梦，但无法生成其他图像内容。如果在此基础上增加数据量、模型参数量、GPU 数量，你就可以训练出一个更强大的文生图模型！\n"
  },
  {
    "path": "docs/zh/Research_Tutorial/train_from_scratch.py",
    "content": "import torch, accelerate\nfrom PIL import Image\nfrom typing import Union\nfrom tqdm import tqdm\nfrom einops import rearrange, repeat\n\nfrom transformers import AutoProcessor, AutoTokenizer\nfrom diffsynth.core import ModelConfig, gradient_checkpoint_forward, attention_forward, UnifiedDataset, load_model\nfrom diffsynth.diffusion import FlowMatchScheduler, DiffusionTrainingModule, FlowMatchSFTLoss, ModelLogger, launch_training_task\nfrom diffsynth.diffusion.base_pipeline import BasePipeline, PipelineUnit\nfrom diffsynth.models.general_modules import TimestepEmbeddings\nfrom diffsynth.models.z_image_text_encoder import ZImageTextEncoder\nfrom diffsynth.models.flux2_vae import Flux2VAE\n\n\nclass AAAPositionalEmbedding(torch.nn.Module):\n    def __init__(self, height=16, width=16, dim=1024):\n        super().__init__()\n        self.image_emb = torch.nn.Parameter(torch.randn((1, dim, height, width)))\n        self.text_emb = torch.nn.Parameter(torch.randn((dim,)))\n\n    def forward(self, image, text):\n        height, width = image.shape[-2:]\n        image_emb = self.image_emb.to(device=image.device, dtype=image.dtype)\n        image_emb = torch.nn.functional.interpolate(image_emb, size=(height, width), mode=\"bilinear\")\n        image_emb = rearrange(image_emb, \"B C H W -> B (H W) C\")\n        text_emb = self.text_emb.to(device=text.device, dtype=text.dtype)\n        text_emb = repeat(text_emb, \"C -> B L C\", B=text.shape[0], L=text.shape[1])\n        emb = torch.concat([image_emb, text_emb], dim=1)\n        return emb\n\n\nclass AAABlock(torch.nn.Module):\n    def __init__(self, dim=1024, num_heads=32):\n        super().__init__()\n        self.norm_attn = torch.nn.RMSNorm(dim, elementwise_affine=False)\n        self.to_q = torch.nn.Linear(dim, dim)\n        self.to_k = torch.nn.Linear(dim, dim)\n        self.to_v = torch.nn.Linear(dim, dim)\n        self.to_out = torch.nn.Linear(dim, dim)\n        self.norm_mlp = torch.nn.RMSNorm(dim, 
elementwise_affine=False)\n        self.ff = torch.nn.Sequential(\n            torch.nn.Linear(dim, dim*3),\n            torch.nn.SiLU(),\n            torch.nn.Linear(dim*3, dim),\n        )\n        self.to_gate = torch.nn.Linear(dim, dim * 2)\n        self.num_heads = num_heads\n\n    def attention(self, emb, pos_emb):\n        emb = self.norm_attn(emb + pos_emb)\n        q, k, v = self.to_q(emb), self.to_k(emb), self.to_v(emb)\n        emb = attention_forward(\n            q, k, v,\n            q_pattern=\"b s (n d)\", k_pattern=\"b s (n d)\", v_pattern=\"b s (n d)\", out_pattern=\"b s (n d)\",\n            dims={\"n\": self.num_heads},\n        )\n        emb = self.to_out(emb)\n        return emb\n    \n    def feed_forward(self, emb, pos_emb):\n        emb = self.norm_mlp(emb + pos_emb)\n        emb = self.ff(emb)\n        return emb\n    \n    def forward(self, emb, pos_emb, t_emb):\n        gate_attn, gate_mlp = self.to_gate(t_emb).chunk(2, dim=-1)\n        emb = emb + self.attention(emb, pos_emb) * (1 + gate_attn)\n        emb = emb + self.feed_forward(emb, pos_emb) * (1 + gate_mlp)\n        return emb\n\n\nclass AAADiT(torch.nn.Module):\n    def __init__(self, dim=1024):\n        super().__init__()\n        self.pos_embedder = AAAPositionalEmbedding(dim=dim)\n        self.timestep_embedder = TimestepEmbeddings(256, dim)\n        self.image_embedder = torch.nn.Sequential(torch.nn.Linear(128, dim), torch.nn.LayerNorm(dim))\n        self.text_embedder = torch.nn.Sequential(torch.nn.Linear(1024, dim), torch.nn.LayerNorm(dim))\n        self.blocks = torch.nn.ModuleList([AAABlock(dim) for _ in range(10)])\n        self.proj_out = torch.nn.Linear(dim, 128)\n\n    def forward(\n        self,\n        latents,\n        prompt_embeds,\n        timestep,\n        use_gradient_checkpointing=False,\n        use_gradient_checkpointing_offload=False,\n    ):\n        pos_emb = self.pos_embedder(latents, prompt_embeds)\n        t_emb = self.timestep_embedder(timestep, 
dtype=latents.dtype).view(1, 1, -1)\n        image = self.image_embedder(rearrange(latents, \"B C H W -> B (H W) C\"))\n        text = self.text_embedder(prompt_embeds)\n        emb = torch.concat([image, text], dim=1)\n        for block_id, block in enumerate(self.blocks):\n            emb = gradient_checkpoint_forward(\n                block,\n                use_gradient_checkpointing=use_gradient_checkpointing,\n                use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,\n                emb=emb,\n                pos_emb=pos_emb,\n                t_emb=t_emb,\n            )\n        emb = emb[:, :latents.shape[-1] * latents.shape[-2]]\n        emb = self.proj_out(emb)\n        emb = rearrange(emb, \"B (H W) C -> B C H W\", W=latents.shape[-1])\n        return emb\n\n\nclass AAAImagePipeline(BasePipeline):\n    def __init__(self, device=\"cuda\", torch_dtype=torch.bfloat16):\n        super().__init__(\n            device=device, torch_dtype=torch_dtype,\n            height_division_factor=16, width_division_factor=16,\n        )\n        self.scheduler = FlowMatchScheduler(\"FLUX.2\")\n        self.text_encoder: ZImageTextEncoder = None\n        self.dit: AAADiT = None\n        self.vae: Flux2VAE = None\n        self.tokenizer: AutoProcessor = None\n        self.in_iteration_models = (\"dit\",)\n        self.units = [\n            AAAUnit_PromptEmbedder(),\n            AAAUnit_NoiseInitializer(),\n            AAAUnit_InputImageEmbedder(),\n        ]\n        self.model_fn = model_fn_aaa\n    \n    @staticmethod\n    def from_pretrained(\n        torch_dtype: torch.dtype = torch.bfloat16,\n        device: Union[str, torch.device] = \"cuda\",\n        model_configs: list[ModelConfig] = [],\n        tokenizer_config: ModelConfig = None,\n        vram_limit: float = None,\n    ):\n        # Initialize pipeline\n        pipe = AAAImagePipeline(device=device, torch_dtype=torch_dtype)\n        model_pool = 
pipe.download_and_load_models(model_configs, vram_limit)\n        \n        # Fetch models\n        pipe.text_encoder = model_pool.fetch_model(\"z_image_text_encoder\")\n        pipe.dit = model_pool.fetch_model(\"aaa_dit\")\n        pipe.vae = model_pool.fetch_model(\"flux2_vae\")\n        if tokenizer_config is not None:\n            tokenizer_config.download_if_necessary()\n            pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)\n        \n        # VRAM Management\n        pipe.vram_management_enabled = pipe.check_vram_management_state()\n        return pipe\n    \n    @torch.no_grad()\n    def __call__(\n        self,\n        # Prompt\n        prompt: str,\n        negative_prompt: str = \"\",\n        cfg_scale: float = 1.0,\n        # Image\n        input_image: Image.Image = None,\n        denoising_strength: float = 1.0,\n        # Shape\n        height: int = 1024,\n        width: int = 1024,\n        # Randomness\n        seed: int = None,\n        rand_device: str = \"cpu\",\n        # Steps\n        num_inference_steps: int = 30,\n        # Progress bar\n        progress_bar_cmd = tqdm,\n    ):\n        self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=height//16*width//16)\n\n        # Parameters\n        inputs_posi = {\"prompt\": prompt}\n        inputs_nega = {\"negative_prompt\": negative_prompt}\n        inputs_shared = {\n            \"cfg_scale\": cfg_scale,\n            \"input_image\": input_image, \"denoising_strength\": denoising_strength,\n            \"height\": height, \"width\": width,\n            \"seed\": seed, \"rand_device\": rand_device,\n            \"num_inference_steps\": num_inference_steps,\n        }\n        for unit in self.units:\n            inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)\n\n        # Denoise\n        self.load_models_to_device(self.in_iteration_models)\n     
   models = {name: getattr(self, name) for name in self.in_iteration_models}\n        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):\n            timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)\n            noise_pred = self.cfg_guided_model_fn(\n                self.model_fn, cfg_scale,\n                inputs_shared, inputs_posi, inputs_nega,\n                **models, timestep=timestep, progress_id=progress_id\n            )\n            inputs_shared[\"latents\"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)\n        \n        # Decode\n        self.load_models_to_device(['vae'])\n        image = self.vae.decode(inputs_shared[\"latents\"])\n        image = self.vae_output_to_image(image)\n        self.load_models_to_device([])\n\n        return image\n\n\nclass AAAUnit_PromptEmbedder(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            seperate_cfg=True,\n            input_params_posi={\"prompt\": \"prompt\"},\n            input_params_nega={\"prompt\": \"negative_prompt\"},\n            output_params=(\"prompt_embeds\",),\n            onload_model_names=(\"text_encoder\",)\n        )\n        self.hidden_states_layers = (-1,)\n\n    def process(self, pipe: AAAImagePipeline, prompt):\n        pipe.load_models_to_device(self.onload_model_names)\n        text = pipe.tokenizer.apply_chat_template(\n            [{\"role\": \"user\", \"content\": prompt}],\n            tokenize=False,\n            add_generation_prompt=True,\n            enable_thinking=False,\n        )\n        inputs = pipe.tokenizer(text, return_tensors=\"pt\", padding=\"max_length\", truncation=True, max_length=128).to(pipe.device)\n        output = pipe.text_encoder(**inputs, output_hidden_states=True, use_cache=False)\n        prompt_embeds = torch.concat([output.hidden_states[k] for k in self.hidden_states_layers], dim=-1)\n        return 
{\"prompt_embeds\": prompt_embeds}\n\n\nclass AAAUnit_NoiseInitializer(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            input_params=(\"height\", \"width\", \"seed\", \"rand_device\"),\n            output_params=(\"noise\",),\n        )\n\n    def process(self, pipe: AAAImagePipeline, height, width, seed, rand_device):\n        noise = pipe.generate_noise((1, 128, height//16, width//16), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)\n        return {\"noise\": noise}\n\n\nclass AAAUnit_InputImageEmbedder(PipelineUnit):\n    def __init__(self):\n        super().__init__(\n            input_params=(\"input_image\", \"noise\"),\n            output_params=(\"latents\", \"input_latents\"),\n            onload_model_names=(\"vae\",)\n        )\n\n    def process(self, pipe: AAAImagePipeline, input_image, noise):\n        if input_image is None:\n            return {\"latents\": noise, \"input_latents\": None}\n        pipe.load_models_to_device(['vae'])\n        image = pipe.preprocess_image(input_image)\n        input_latents = pipe.vae.encode(image)\n        if pipe.scheduler.training:\n            return {\"latents\": noise, \"input_latents\": input_latents}\n        else:\n            latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])\n            return {\"latents\": latents, \"input_latents\": input_latents}\n\n\ndef model_fn_aaa(\n    dit: AAADiT,\n    latents=None,\n    prompt_embeds=None,\n    timestep=None,\n    use_gradient_checkpointing=False,\n    use_gradient_checkpointing_offload=False,\n    **kwargs,\n):\n    model_output = dit(\n        latents,\n        prompt_embeds,\n        timestep,\n        use_gradient_checkpointing=use_gradient_checkpointing,\n        use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,\n    )\n    return model_output\n\n\nclass AAATrainingModule(DiffusionTrainingModule):\n    def __init__(self, device):\n        
super().__init__()\n        self.pipe = AAAImagePipeline.from_pretrained(\n            torch_dtype=torch.bfloat16,\n            device=device,\n            model_configs=[\n                ModelConfig(model_id=\"Qwen/Qwen3-0.6B\", origin_file_pattern=\"model.safetensors\"),\n                ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-4B\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n            ],\n            tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen3-0.6B\", origin_file_pattern=\"./\"),\n        )\n        self.pipe.dit = AAADiT().to(dtype=torch.bfloat16, device=device)\n        self.pipe.freeze_except([\"dit\"])\n        self.pipe.scheduler.set_timesteps(1000, training=True)\n\n    def forward(self, data):\n        inputs_posi = {\"prompt\": data[\"prompt\"]}\n        inputs_nega = {\"negative_prompt\": \"\"}\n        inputs_shared = {\n            \"input_image\": data[\"image\"],\n            \"height\": data[\"image\"].size[1],\n            \"width\": data[\"image\"].size[0],\n            \"cfg_scale\": 1,\n            \"use_gradient_checkpointing\": False,\n            \"use_gradient_checkpointing_offload\": False,\n        }\n        for unit in self.pipe.units:\n            inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega)\n        loss = FlowMatchSFTLoss(self.pipe, **inputs_shared, **inputs_posi)\n        return loss\n\n\nif __name__ == \"__main__\":\n    accelerator = accelerate.Accelerator(gradient_accumulation_steps=1)\n    dataset = UnifiedDataset(\n        base_path=\"data/images\",\n        metadata_path=\"data/metadata_merged.csv\",\n        max_data_items=10000000,\n        data_file_keys=(\"image\",),\n        main_data_operator=UnifiedDataset.default_image_operator(base_path=\"data/images\", height=256, width=256)\n    )\n    model = AAATrainingModule(device=accelerator.device)\n    model_logger = ModelLogger(\n        
\"models/AAA/v1\",\n        remove_prefix_in_ckpt=\"pipe.dit.\",\n    )\n    launch_training_task(\n        accelerator, dataset, model, model_logger,\n        learning_rate=2e-4,\n        num_workers=4,\n        save_steps=50000,\n        num_epochs=999999,\n    )"
  },
  {
    "path": "docs/zh/Training/Differential_LoRA.md",
    "content": "# 差分 LoRA 训练\n\n差分 LoRA 训练是一种特殊的 LoRA 训练方式，旨在让模型学习图像之间的差异。\n\n## 训练方案\n\n我们未能找到差分 LoRA 训练最早由谁提出，这一技术已经在开源社区中流传甚久。\n\n假设我们有两张内容相似的图像：图 1 和图 2。例如两张图中分别有一辆车，但图 1 中画面细节更少，图 2 中画面细节更多。在差分 LoRA 训练中，我们进行两步训练：\n\n* 以图 1 为训练数据，以[标准监督训练](../Training/Supervised_Fine_Tuning.md)的方式，训练 LoRA 1\n* 以图 2 为训练数据，将 LoRA 1 融入基础模型后，以[标准监督训练](../Training/Supervised_Fine_Tuning.md)的方式，训练 LoRA 2\n\n在第一步训练中，由于训练数据仅有一张图，LoRA 模型很容易过拟合，因此训练完成后，LoRA 1 会让模型毫不犹豫地生成图 1，无论随机种子是什么。在第二步训练中，LoRA 模型再次过拟合，因此训练完成后，在 LoRA 1 和 LoRA 2 的共同作用下，模型会毫不犹豫地生成图 2。简言之：\n\n* LoRA 1 = 生成图 1\n* LoRA 1 + LoRA 2 = 生成图 2\n\n此时丢弃 LoRA 1，只使用 LoRA 2，模型将会理解图 1 和图 2 的差异，使生成的内容倾向于“更不像图 1，更像图 2”。\n\n单一训练数据可以保证模型能够过拟合到训练数据上，但稳定性不足。为了提高稳定性，我们可以用多个图像对（image pairs）进行训练，并将训练出的 LoRA 2 进行平均，得到效果更稳定的 LoRA。\n\n用这一训练方案，可以训练出一些功能奇特的 LoRA 模型。例如，使用丑陋的和漂亮的图像对，训练提升图像美感的 LoRA；使用细节少的和细节丰富的图像对，训练增加图像细节的 LoRA。\n\n## 模型效果\n\n我们用差分 LoRA 训练技术训练了几个美学提升 LoRA，可前往对应的模型页面查看生成效果。\n\n* [DiffSynth-Studio/Qwen-Image-LoRA-ArtAug-v1](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-LoRA-ArtAug-v1)\n* [DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1](https://modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1)\n\n## 在训练框架中使用差分 LoRA 训练\n\n第一步的训练与普通 LoRA 训练没有任何差异，在第二步的训练命令中，通过 `--preset_lora_path` 参数填入第一步的 LoRA 模型文件路径，并将 `--preset_lora_model` 设置为与 `--lora_base_model` 相同的值，即可将 LoRA 1 加载到基础模型中。\n\n## 框架设计思路\n\n在训练框架中，`--preset_lora_path` 指向的模型在 `DiffusionTrainingModule` 的 `switch_pipe_to_training_mode` 中完成加载。\n"
  },
  {
    "path": "docs/zh/Training/Direct_Distill.md",
    "content": "# 端到端的蒸馏加速训练\n\n## 蒸馏加速训练\n\nDiffusion 模型的推理过程通常需要多步迭代，在提升生成效果的同时也让生成过程变得缓慢。通过蒸馏加速训练，可以减少生成清晰内容所需的步数。蒸馏加速训练技术的本质训练目标是让少量步数的生成效果与大量步数的生成效果对齐。\n\n蒸馏加速训练的方法是多样的，例如\n\n* 对抗式训练 ADD（Adversarial Diffusion Distillation）\n    * 论文：https://arxiv.org/abs/2311.17042\n    * 模型：[stabilityai/sdxl-turbo](https://modelscope.cn/models/stabilityai/sdxl-turbo)\n* 渐进式训练 Hyper-SD\n    * 论文：https://arxiv.org/abs/2404.13686\n    * 模型：[ByteDance/Hyper-SD](https://www.modelscope.cn/models/ByteDance/Hyper-SD)\n\n## 直接蒸馏\n\n在训练框架层面，支持这类蒸馏加速训练方案是极其困难的。在训练框架的设计中，我们需要保证训练方案满足以下条件：\n\n* 通用性：训练方案适用于大多数框架内支持的 Diffusion 模型，而非只能对某个特定模型生效，这是代码框架建设的基本要求。\n* 稳定性：训练方案需保证训练效果稳定，不需要人工进行精细的参数调整，ADD 中的对抗式训练则无法保证稳定性。\n* 简洁性：训练方案不会引入额外的复杂模块，根据奥卡姆剃刀（[Occam's Razor](https://en.wikipedia.org/wiki/Occam%27s_razor)）原理，复杂解决方案可能引入潜在风险，Hyper-SD 中的 Human Feedback Learning 让训练过程变得过于复杂。\n\n因此，在 `DiffSynth-Studio` 的训练框架中，我们设计了一个端到端的蒸馏加速训练方案，我们称为直接蒸馏（Direct Distill），其训练过程的伪代码如下：\n\n```\nseed = xxx\nwith torch.no_grad():\n    image_1 = pipe(prompt, steps=50, seed=seed, cfg=4)\nimage_2 = pipe(prompt, steps=4, seed=seed, cfg=1)\nloss = torch.nn.functional.mse_loss(image_1, image_2)\n```\n\n是的，非常端到端的训练方案，稍加训练就可以有立竿见影的效果。\n\n## 直接蒸馏训练的模型\n\n我们用这个方案基于 Qwen-Image 训练了两个模型：\n\n* [DiffSynth-Studio/Qwen-Image-Distill-Full](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full): 全量蒸馏训练\n* [DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA): LoRA 蒸馏训练\n\n点击模型链接即可前往模型页面查看模型效果。\n\n## 在训练框架中使用蒸馏加速训练\n\n首先，需要生成训练数据，请参考[模型推理](../Pipeline_Usage/Model_Inference.md)部分编写推理代码，以足够多的推理步数生成训练数据。\n\n以 Qwen-Image 为例，以下代码可以生成一张图片：\n\n```python\nfrom diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nimport torch\n\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", 
origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n)\nprompt = \"精致肖像，水下少女，蓝裙飘逸，发丝轻扬，光影透澈，气泡环绕，面容恬静，细节精致，梦幻唯美。\"\nimage = pipe(prompt, seed=0, num_inference_steps=40)\nimage.save(\"image.jpg\")\n```\n\n然后，我们把必要的信息编写成[元数据文件](../API_Reference/core/data.md#元数据)：\n\n```csv\nimage,prompt,seed,rand_device,num_inference_steps,cfg_scale\ndistill_qwen/image.jpg,\"精致肖像，水下少女，蓝裙飘逸，发丝轻扬，光影透澈，气泡环绕，面容恬静，细节精致，梦幻唯美。\",0,cpu,4,1\n```\n\n这个样例数据集可以直接下载：\n\n```shell\nmodelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --local_dir ./data/diffsynth_example_dataset\n```\n\n然后开始 LoRA 蒸馏加速训练：\n\n```shell\nbash examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh\n```\n\n请注意，在[训练脚本参数](../Pipeline_Usage/Model_Training.md#脚本参数)中，数据集的图像分辨率设置要避免触发缩放处理。当设定 `--height` 和 `--width` 以启用固定分辨率时，所有训练数据必须是以完全一致的宽高生成的；当设定 `--max_pixels` 以启用动态分辨率时，`--max_pixels` 的数值必须大于或等于任一训练图像的像素面积。\n\n## 训练框架设计思路\n\n直接蒸馏与[标准监督训练](../Training/Supervised_Fine_Tuning.md)相比，仅训练的损失函数不同，直接蒸馏的损失函数是 `diffsynth.diffusion.loss` 中的 `DirectDistillLoss`。\n\n## 未来工作\n\n直接蒸馏是通用性很强的加速方案，但未必是效果最好的方案，所以我们暂未把这一技术以论文的形式发布。我们希望把这个问题交给学术界和开源社区共同解决，期待开发者能够给出更完善的通用训练方案。\n"
  },
  {
    "path": "docs/zh/Training/FP8_Precision.md",
    "content": "# 在训练中启用 FP8 精度\n\n尽管 `DiffSynth-Studio` 在模型推理中支持[显存管理](../Pipeline_Usage/VRAM_management.md)，但其中的大部分减少显存占用的技术不适合用于训练中，Offload 会导致极为缓慢的训练过程。\n\nFP8 精度是唯一可在训练过程中启用的显存管理策略，但本框架目前不支持原生 FP8 精度训练，原因详见 [Q&A: 为什么训练框架不支持原生 FP8 精度训练？](../QA.md#为什么训练框架不支持原生-fp8-精度训练)，仅支持将参数不被梯度更新的模型（不需要梯度回传，或梯度仅更新其 LoRA）以 FP8 精度进行存储。\n\n## 启用 FP8\n\n在我们提供的训练脚本中，通过参数 `--fp8_models` 即可快速设置以 FP8 精度存储的模型。以 Qwen-Image 的 LoRA 训练为例，我们提供了启用 FP8 训练的脚本，位于 [`/examples/qwen_image/model_training/special/fp8_training/Qwen-Image-LoRA.sh`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/special/fp8_training/Qwen-Image-LoRA.sh)。训练完成后，可通过脚本 [`/examples/qwen_image/model_training/special/fp8_training/validate.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/special/fp8_training/validate.py) 验证训练效果。\n\n请注意，这种 FP8 显存管理策略不支持梯度更新，当某个模型被设置为可训练时，不能为这个模型开启 FP8 精度，支持开启 FP8 的模型包括两类：\n\n* 参数不可训练，例如 VAE 模型\n* 梯度不更新其参数，例如 LoRA 训练中的 DiT 模型\n\n经实验验证，开启 FP8 后的 LoRA 训练效果没有明显的图像质量下降，但理论上误差是确实存在的，如果在使用本功能时遇到训练效果不如 BF16 精度训练的问题，请通过 GitHub issue 给我们提供反馈。\n\n## 训练框架设计思路\n\n训练框架完全沿用推理的显存管理，在训练中仅通过 `DiffusionTrainingModule` 中的 `parse_model_configs` 解析显存管理配置。\n"
  },
  {
    "path": "docs/zh/Training/Split_Training.md",
    "content": "# 两阶段拆分训练\n\n本文档介绍拆分训练，能够自动将训练过程拆分为两阶段进行，减少显存占用，同时加快训练速度。\n\n（拆分训练是实验性特性，尚未进行大规模验证，如果在使用中出现问题，请在 GitHub 上给我们提 issue。）\n\n## 拆分训练\n\n在大部分模型的训练过程中，大量计算发生在“前处理”中，即“与去噪模型无关的计算”，包括 VAE 编码、文本编码等。当对应的模型参数固定时，这部分计算的结果是重复的，在多个 epoch 中每个数据样本的计算结果完全相同，因此我们提供了“拆分训练”功能，该功能可以自动分析并拆分训练过程。\n\n对于普通文生图模型的标准监督训练，拆分过程是非常简单的，只需要把所有 [`Pipeline Units`](../Developer_Guide/Building_a_Pipeline.md#units) 的计算拆分到第一阶段，将计算结果存储到硬盘中，然后在第二阶段从硬盘中读取这些结果并进行后续计算即可。但如果前处理过程中需要梯度回传，情况就变得极其复杂，为此，我们引入了一个计算图拆分算法用于分析如何拆分计算。\n\n## 计算图拆分算法\n\n> （我们会在后续的文档更新中补充计算图拆分算法的详细细节）\n\n## 使用拆分训练\n\n拆分训练已支持[标准监督训练](../Training/Supervised_Fine_Tuning.md)和[直接蒸馏训练](../Training/Direct_Distill.md)，在训练命令中通过 `--task` 参数控制，以 Qwen-Image 模型的 LoRA 训练为例，拆分前的训练命令为：\n\n```shell\naccelerate launch examples/qwen_image/model_training/train.py \\\n  --dataset_base_path data/example_image_dataset \\\n  --dataset_metadata_path data/example_image_dataset/metadata.csv \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 50 \\\n  --model_id_with_origin_paths \"Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Qwen-Image_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1\" \\\n  --lora_rank 32 \\\n  --use_gradient_checkpointing \\\n  --dataset_num_workers 8 \\\n  --find_unused_parameters\n```\n\n拆分后，在第一阶段中，做如下修改：\n\n* 将 `--dataset_repeat` 改为 1，避免重复计算\n* 将 `--output_path` 改为第一阶段计算结果保存的路径\n* 添加额外参数 `--task \"sft:data_process\"`\n* 删除 `--model_id_with_origin_paths` 中的 DiT 模型\n\n```shell\naccelerate launch examples/qwen_image/model_training/train.py \\\n  --dataset_base_path data/example_image_dataset \\\n  --dataset_metadata_path 
data/example_image_dataset/metadata.csv \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 1 \\\n  --model_id_with_origin_paths \"Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Qwen-Image-LoRA-splited-cache\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1\" \\\n  --lora_rank 32 \\\n  --use_gradient_checkpointing \\\n  --dataset_num_workers 8 \\\n  --find_unused_parameters \\\n  --task \"sft:data_process\"\n```\n\n在第二阶段，做如下修改：\n\n* 将 `--dataset_base_path` 改为第一阶段的 `--output_path`\n* 删除 `--dataset_metadata_path`\n* 添加额外参数 `--task \"sft:train\"`\n* 删除 `--model_id_with_origin_paths` 中的 Text Encoder 和 VAE 模型\n\n```shell\naccelerate launch examples/qwen_image/model_training/train.py \\\n  --dataset_base_path \"./models/train/Qwen-Image-LoRA-splited-cache\" \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 50 \\\n  --model_id_with_origin_paths \"Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Qwen-Image-LoRA-splited\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1\" \\\n  --lora_rank 32 \\\n  --use_gradient_checkpointing \\\n  --dataset_num_workers 8 \\\n  --find_unused_parameters \\\n  --task \"sft:train\"\n```\n\n我们提供了样例训练脚本和验证脚本，位于 `examples/qwen_image/model_training/special/split_training`。\n\n## 训练框架设计思路\n\n训练框架通过 `DiffusionTrainingModule` 的 `split_pipeline_units` 方法拆分 `Pipeline` 中的计算单元。\n"
  },
  {
    "path": "docs/zh/Training/Supervised_Fine_Tuning.md",
    "content": "# 标准监督训练\n\n在理解 [Diffusion 模型基本原理](../Training/Understanding_Diffusion_models.md)之后，本文档介绍框架如何实现 Diffusion 模型的训练。本文档介绍框架的原理，帮助开发者编写新的训练代码，如需使用我们提供的默认训练功能，请参考[模型训练](../Pipeline_Usage/Model_Training.md)。\n\n回顾前文中的模型训练伪代码，当我们实际编写代码时，情况会变得极为复杂。部分模型需要输入额外的引导条件并进行预处理，例如 ControlNet；部分模型需要与去噪模型进行交叉式的计算，例如 VACE；部分模型因显存需求过大，需要开启 Gradient Checkpointing，例如 Qwen-Image 的 DiT。\n\n为了实现严格的推理和训练一致性，我们对 `Pipeline` 等组件进行了抽象封装，在训练过程中大量复用推理代码。请参考[接入 Pipeline](../Developer_Guide/Building_a_Pipeline.md) 了解 `Pipeline` 组件的设计。接下来我们介绍训练框架如何利用 `Pipeline` 组件构建训练算法。\n\n## 框架设计思路\n\n训练模块在 `Pipeline` 上层进行封装，继承 `diffsynth.diffusion.training_module` 中的 `DiffusionTrainingModule`，我们需为训练模块提供必要的 `__init__` 和 `forward` 方法。我们以 Qwen-Image 的 LoRA 训练为例，在 `examples/qwen_image/model_training/special/simple/train.py` 中提供了仅包含基础训练功能的简易脚本，帮助开发者理解训练模块的设计思路。\n\n```python\nclass QwenImageTrainingModule(DiffusionTrainingModule):\n    def __init__(self, device):\n        # Initialize models here.\n        pass\n\n    def forward(self, data):\n        # Compute loss here.\n        return loss\n```\n\n### `__init__`\n\n在 `__init__` 中需进行模型的初始化，先加载模型，然后将其切换到训练模式。\n\n```python\n    def __init__(self, device):\n        super().__init__()\n        # Load the pipeline\n        self.pipe = QwenImagePipeline.from_pretrained(\n            torch_dtype=torch.bfloat16,\n            device=device,\n            model_configs=[\n                ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n                ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n                ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n            ],\n            tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n        )\n        # Switch to training mode\n        self.switch_pipe_to_training_mode(\n            
self.pipe,\n            lora_base_model=\"dit\",\n            lora_target_modules=\"to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj\",\n            lora_rank=32,\n        )\n```\n\n加载模型的逻辑与推理时基本一致，支持从远程和本地路径加载模型，详见[模型推理](../Pipeline_Usage/Model_Inference.md)，但请注意不要启用[显存管理](../Pipeline_Usage/VRAM_management.md)。\n\n`switch_pipe_to_training_mode` 可以将模型切换到训练模式，其实现详见 `diffsynth.diffusion.training_module` 中 `DiffusionTrainingModule` 的同名方法。\n\n### `forward`\n\n在 `forward` 中需计算损失函数值，先进行前处理，然后经过 `Pipeline` 的 [`model_fn`](../Developer_Guide/Building_a_Pipeline.md#model_fn) 计算损失函数。\n\n```python\n    def forward(self, data):\n        # Preprocess\n        inputs_posi = {\"prompt\": data[\"prompt\"]}\n        inputs_nega = {\"negative_prompt\": \"\"}\n        inputs_shared = {\n            # Assume you are using this pipeline for inference,\n            # please fill in the input parameters.\n            \"input_image\": data[\"image\"],\n            \"height\": data[\"image\"].size[1],\n            \"width\": data[\"image\"].size[0],\n            # Please do not modify the following parameters\n            # unless you clearly know what this will cause.\n            \"cfg_scale\": 1,\n            \"rand_device\": self.pipe.device,\n            \"use_gradient_checkpointing\": True,\n            \"use_gradient_checkpointing_offload\": False,\n        }\n        for unit in self.pipe.units:\n            inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega)\n        # Loss\n        loss = FlowMatchSFTLoss(self.pipe, **inputs_shared, **inputs_posi)\n        return loss\n```\n\n前处理过程与推理阶段一致，开发者只需假定在使用 `Pipeline` 进行推理，将输入参数填入即可。\n\n损失函数的计算沿用 `diffsynth.diffusion.loss` 中的 `FlowMatchSFTLoss`。\n\n### 开始训练\n\n训练框架还需其他模块，包括：\n\n* accelerator: `accelerate` 提供的训练启动器，详见 [`accelerate`](https://huggingface.co/docs/accelerate/index)\n* dataset: 通用数据集，详见 [`diffsynth.core.data`](../API_Reference/core/data.md)\n* model_logger: 模型记录器，详见 
`diffsynth.diffusion.logger`\n\n```python\nif __name__ == \"__main__\":\n    accelerator = accelerate.Accelerator(\n        kwargs_handlers=[accelerate.DistributedDataParallelKwargs(find_unused_parameters=True)],\n    )\n    dataset = UnifiedDataset(\n        base_path=\"data/example_image_dataset\",\n        metadata_path=\"data/example_image_dataset/metadata.csv\",\n        repeat=50,\n        data_file_keys=\"image\",\n        main_data_operator=UnifiedDataset.default_image_operator(\n            base_path=\"data/example_image_dataset\",\n            height=512,\n            width=512,\n            height_division_factor=16,\n            width_division_factor=16,\n        )\n    )\n    model = QwenImageTrainingModule(accelerator.device)\n    model_logger = ModelLogger(\n        output_path=\"models/toy_model\",\n        remove_prefix_in_ckpt=\"pipe.dit.\",\n    )\n    launch_training_task(\n        accelerator, dataset, model, model_logger,\n        learning_rate=1e-5, num_epochs=1,\n    )\n```\n\n将以上所有代码组装，得到 `examples/qwen_image/model_training/special/simple/train.py`。使用以下命令即可启动训练：\n\n```\naccelerate launch examples/qwen_image/model_training/special/simple/train.py\n```\n"
  },
  {
    "path": "docs/zh/Training/Understanding_Diffusion_models.md",
    "content": "# Diffusion 模型基本原理\n\n本文介绍 Diffusion 模型的基本原理，帮助你理解训练框架是如何构建的。为了让读者更轻松地理解这些复杂的数学理论，我们重构了 Diffusion 模型的理论框架，抛弃了复杂的随机微分方程，用一种更简洁易懂的形式进行介绍。\n\n## 引言\n\nDiffusion 模型通过多步迭代式地去噪（denoise）生成清晰的图像或视频内容，我们从一个数据样本 $x_0$ 的生成过程开始讲起。直观地，在完整的一轮 denoise 过程中，我们从随机高斯噪声 $x_T$ 开始，通过迭代依次得到 $x_{T-1}$、$x_{T-2}$、$x_{T-3}$、$\\cdots$，在每一步中逐渐减少噪声含量，最终得到不含噪声的数据样本 $x_0$。\n\n![Image](https://github.com/user-attachments/assets/6471ae4c-a635-4924-8b36-b0bd4d42043d)\n\n这个过程是很直观的，但如果要理解其中的细节，我们就需要回答这几个问题：\n\n* 每一步的噪声含量是如何定义的？\n* 迭代去噪的计算是如何进行的？\n* 如何训练这样的 Diffusion 模型？\n* 现代 Diffusion 模型的架构是什么样的？\n* 本项目如何封装和实现模型训练？\n\n## 每一步的噪声含量是如何定义的？\n\n在 Diffusion 模型的理论体系中，噪声的含量是由一系列参数 $\\sigma_T$、$\\sigma_{T-1}$、$\\sigma_{T-2}$、$\\cdots$、$\\sigma_0$ 决定的。其中\n\n* $\\sigma_T=1$，对应的 $x_T$ 为纯粹的高斯噪声\n* $\\sigma_T>\\sigma_{T-1}>\\sigma_{T-2}>\\cdots>\\sigma_0$，在迭代过程中噪声含量逐渐减小\n* $\\sigma_0=0$，对应的 $x_0$ 为不含任何噪声的数据样本\n\n至于中间 $\\sigma_{T-1}$、$\\sigma_{T-2}$、$\\cdots$、$\\sigma_1$ 的数值，则不是固定的，满足递减的条件即可。\n\n那么在中间的某一步，我们可以直接合成含噪声的数据样本 $x_t=(1-\\sigma_t)x_0+\\sigma_t x_T$。\n\n![Image](https://github.com/user-attachments/assets/e25a2f71-123c-4e18-8b34-3a066af15667)\n\n## 迭代去噪的计算是如何进行的？\n\n在理解迭代去噪的计算前，我们要先搞清楚，去噪模型的输入和输出是什么。我们把模型抽象成一个符号 $\\hat \\epsilon$，它的输入通常包含三部分\n\n* 时间步 $t$，模型需要理解当前处于去噪过程的哪个阶段\n* 含噪声的数据样本 $x_t$，模型需要理解要对什么数据进行去噪\n* 引导条件 $c$，模型需要理解要通过去噪生成什么样的数据样本\n\n其中，引导条件 $c$ 是新引入的参数，它是由用户输入的，可以是用于描述图像内容的文本，也可以是用于勾勒图像结构的线稿图。\n\n而模型的输出 $\\hat \\epsilon(x_t,c,t)$，则近似地等于 $x_T-x_0$，也就是整个扩散过程（去噪过程的反向过程）的方向。\n\n接下来我们分析一步迭代中发生的计算，在时间步 $t$，模型通过计算得到近似的 $x_T-x_0$ 后，我们计算下一步的 $x_{t-1}$：\n\n$$\n\\begin{aligned}\nx_{t-1}&=x_t + (\\sigma_{t-1} - \\sigma_t) \\cdot \\hat \\epsilon(x_t,c,t)\\\\\n&\\approx x_t + (\\sigma_{t-1} - \\sigma_t) \\cdot (x_T-x_0)\\\\\n&=(1-\\sigma_t)x_0+\\sigma_t x_T + (\\sigma_{t-1} - \\sigma_t) \\cdot (x_T-x_0)\\\\\n&=(1-\\sigma_{t-1})x_0+\\sigma_{t-1}x_T\n\\end{aligned}\n$$\n\n完美！与时间步 $t-1$ 时的噪声含量定义完美契合。\n\n> （这部分可能有点难懂，请不必担心，首次阅读本文时建议跳过这部分，不影响后文的阅读。）\n>\n> 完成了这段有点复杂的公式推导后，我们思考一个问题，为什么模型的输出要近似地等于 $x_T-x_0$ 
呢？可以设定成其他值吗？\n>\n> 实际上，Diffusion 模型依赖两个定义形成完备的理论。在以上的公式中，我们可以提炼出这两个定义，并导出迭代公式：\n>\n> * 数据定义：$x_t=(1-\\sigma_t)x_0+\\sigma_t x_T$\n> * 模型定义：$\\hat \\epsilon(x_t,c,t)=x_T-x_0$\n> * 导出迭代公式：$x_{t-1}=x_t + (\\sigma_{t-1} - \\sigma_t) \\cdot \\hat \\epsilon(x_t,c,t)$\n>\n> 这三个数学公式是完备的，例如在刚才的推导中，我们把数据定义和模型定义代入迭代公式，可以得到与数据定义吻合的 $x_{t-1}$。\n>\n> 这是基于 Flow Matching 理论构建的两个定义，但 Diffusion 模型也可用其他的两个定义来实现，例如早期基于 DDPM（Denoising Diffusion Probabilistic Models）的模型，其两个定义及导出的迭代公式为：\n>\n> * 数据定义：$x_t=\\sqrt{\\alpha_t}x_0+\\sqrt{1-\\alpha_t}x_T$\n> * 模型定义：$\\hat \\epsilon(x_t,c,t)=x_T$\n> * 导出迭代公式：$x_{t-1}=\\sqrt{\\alpha_{t-1}}\\left(\\frac{x_t-\\sqrt{1-\\alpha_t}\\hat \\epsilon(x_t,c,t)}{\\sqrt{\\alpha_t}}\\right)+\\sqrt{1-\\alpha_{t-1}}\\hat \\epsilon(x_t,c,t)$\n>\n> 更一般地，我们用矩阵描述迭代公式的导出过程，对于任意数据定义和模型定义，有：\n>\n> * 数据定义：$x_t=C_t(x_0,x_T)^T$\n> * 模型定义：$\\hat \\epsilon(x_t,c,t)=C_t^{[\\epsilon]}(x_0,x_T)^T$\n> * 导出迭代公式：$x_{t-1}=C_{t-1}(C_t,C_t^{[\\epsilon]})^{-T}(x_t,\\hat \\epsilon(x_t,c,t))^T$\n>\n> 其中，$C_t$、$C_t^{[\\epsilon]}$ 是 $1\\times 2$ 的系数矩阵，不难发现，在构造两个定义时，需保证矩阵 $(C_t,C_t^{[\\epsilon]})^T$ 是可逆的。\n>\n> 尽管 Flow Matching 与 DDPM 已被大量预训练模型广泛验证过，但这并不代表这是最优的方案，我们鼓励开发者设计新的 Diffusion 模型理论实现更好的训练效果。\n\n## 如何训练这样的 Diffusion 模型？\n\n搞清楚迭代去噪的过程之后，接下来我们考虑如何训练这样的 Diffusion 模型。\n\n训练过程不同于生成过程，如果我们在训练过程中保留多步迭代，那么梯度需经过多步回传，带来的时间和空间复杂度是灾难性的。为了提高计算效率，我们在训练中随机选择某一时间步 $t$ 进行训练。\n\n以下是训练过程的伪代码\n\n> 从数据集获取数据样本 $x_0$ 和引导条件 $c$\n>\n> 随机采样时间步 $t\\in(0,T]$\n>\n> 随机采样高斯噪声 $x_T\\sim \\mathcal N(0,I)$\n>\n> $x_t=(1-\\sigma_t)x_0+\\sigma_t x_T$\n>\n> 计算模型输出 $\\hat \\epsilon(x_t,c,t)$\n>\n> 损失函数 $\\mathcal L=||\\hat \\epsilon(x_t,c,t)-(x_T-x_0)||_2^2$\n>\n> 梯度回传并更新模型参数\n\n## 现代 Diffusion 模型的架构是什么样的？\n\n从理论到实践，还需要填充更多细节。现代 Diffusion 模型架构已经发展成熟，主流的架构沿用了 Latent Diffusion 所提出的“三段式”架构，包括数据编解码器、引导条件编码器、去噪模型三部分。\n\n![Image](https://github.com/user-attachments/assets/43855430-6427-4aca-83a0-f684e01438b1)\n\n### 数据编解码器\n\n在前文中，我们一直将 $x_0$ 称为“数据样本”，而不是图像或视频，这是因为现代 Diffusion 
模型通常不会直接在图像或视频上进行处理，而是用编码器（Encoder）-解码器（Decoder）架构的模型，通常是 VAE（Variational Auto-Encoders）模型，将图像或视频编码为 Embedding 张量，得到 $x_0$。\n\n数据经过编码器编码后，再经过解码器解码，重建后的内容与原来近似地一致，会有少量误差。那么，为什么要在编码后的 Embedding 张量上处理，而不是在图像或视频上直接处理呢？主要原因有两点：\n\n* 编码的同时对数据进行了压缩，编码后处理的计算量更小。\n* 编码后的数据分布与高斯分布更相似，更容易用去噪模型对数据进行建模。\n\n在生成过程中，编码器部分不参与计算，迭代完成后，用解码器部分解码 $x_0$ 即可得到清晰的图像或视频。在训练过程中，解码器部分不参与计算，仅编码器用于计算 $x_0$。\n\n### 引导条件编码器\n\n用户输入的引导条件 $c$ 可能是复杂多样的，需要由专门的编码器模型将其处理成 Embedding 张量。按照引导条件的类型，我们把引导条件编码器分为以下几类：\n\n* 文本类型，例如 CLIP、Qwen-VL\n* 图像类型，例如 ControlNet、IP-Adapter\n* 视频类型，例如 VAE\n\n> 前文中的模型 $\\hat \\epsilon$ 指代此处的所有引导条件编码器和去噪模型这一整体，我们把引导条件编码器单独拆分列出，因为这类模型在 Diffusion 训练中通常是冻结的，且输出值与时间步 $t$ 无关，因此引导条件编码器的计算可以离线进行。\n\n### 去噪模型\n\n去噪模型是 Diffusion 模型真正的本体，其模型结构多种多样，例如 UNet、DiT，模型开发者可在此结构上自由发挥。\n\n## 本项目如何封装和实现模型训练？\n\n请阅读下一文档：[标准监督训练](../Training/Supervised_Fine_Tuning.md)\n"
  },
  {
    "path": "docs/zh/conf.py",
    "content": "# Configuration file for the Sphinx documentation builder.\n#\n# This file only contains a selection of the most common options. For a full\n# list see the documentation:\n# https://www.sphinx-doc.org/en/master/usage/configuration.html\n\n# -- Path setup --------------------------------------------------------------\n\n# If extensions (or modules to document with autodoc) are in another directory,\n# add these directories to sys.path here. If the directory is relative to the\n# documentation root, use os.path.abspath to make it absolute, like shown here.\n#\nimport os\nimport sys\n\n# import sphinx_book_theme\n\nsys.path.insert(0, os.path.abspath('../../'))\n# -- Project information -----------------------------------------------------\n\nproject = 'diffsynth'\ncopyright = '2022-2025, Alibaba ModelScope'\nauthor = 'ModelScope Authors'\nversion_file = '../../diffsynth/version.py'\nhtml_theme = 'sphinx_rtd_theme'\nlanguage = 'zh_CN'\n\n\ndef get_version():\n    # Execute version.py into an explicit namespace: reading names back via\n    # locals() after exec() breaks on Python 3.13+ (PEP 667 snapshot semantics).\n    namespace = {}\n    with open(version_file, 'r', encoding='utf-8') as f:\n        exec(compile(f.read(), version_file, 'exec'), namespace)\n    return namespace['__version__']\n\n\n# The full version, including alpha/beta/rc tags\nversion = get_version()\nrelease = version\n\n# -- General configuration ---------------------------------------------------\n\n# Add any Sphinx extension module names here, as strings. 
They can be\n# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom\n# ones.\nextensions = [\n    'sphinx.ext.napoleon',\n    'sphinx.ext.autosummary',\n    'sphinx.ext.autodoc',\n    'sphinx.ext.viewcode',\n    'sphinx_markdown_tables',\n    'sphinx_copybutton',\n    \"sphinx_rtd_theme\",\n    'sphinx.ext.mathjax',\n    'myst_parser',\n]\n# build the templated autosummary files\nautosummary_generate = True\nnumpydoc_show_class_members = False\n\n# Enable overriding of function signatures in the first line of the docstring.\nautodoc_docstring_signature = True\n\n# Disable docstring inheritance\nautodoc_inherit_docstrings = False\n\n# Show type hints in the description\nautodoc_typehints = 'description'\n\n# Add parameter types if the parameter is documented in the docstring\nautodoc_typehints_description_target = 'documented_params'\n\nautodoc_default_options = {\n    'member-order': 'bysource',\n}\n\n# Add any paths that contain templates here, relative to this directory.\ntemplates_path = ['_templates']\n\n# The suffix(es) of source filenames.\n# You can specify multiple suffix as a list of string:\n#\nsource_suffix = ['.rst', '.md']\n\n# The master toctree document.\nroot_doc = 'index'\n\n# List of patterns, relative to source directory, that match files and\n# directories to ignore when looking for source files.\n# This pattern also affects html_static_path and html_extra_path.\nexclude_patterns = ['build']\n# A list of glob-style patterns [1] that are used to find source files.\n# They are matched against the source file names relative to the source directory,\n# using slashes as directory separators on all platforms.\n# The default is **, meaning that all files are recursively included from the source directory.\n# -- Options for HTML output -------------------------------------------------\n\n# The theme to use for HTML and HTML Help pages.  
See the documentation for\n# a list of builtin themes.\n#\n# html_theme = 'sphinx_book_theme'\n# html_theme_path = [sphinx_book_theme.get_html_theme_path()]\n# html_theme_options = {}\n\n# Add any paths that contain custom static files (such as style sheets) here,\n# relative to this directory. They are copied after the builtin static files,\n# so a file named \"default.css\" will overwrite the builtin \"default.css\".\nhtml_static_path = ['_static']\n# html_css_files = ['css/readthedocs.css']\n\n# -- Options for HTMLHelp output ---------------------------------------------\n# Output file base name for HTML help builder.\n\n# -- Extension configuration -------------------------------------------------\n# Ignore >>> when copying code\ncopybutton_prompt_text = r'>>> |\\.\\.\\. '\ncopybutton_prompt_is_regexp = True\n\n# Example configuration for intersphinx: refer to the Python standard library.\nintersphinx_mapping = {'https://docs.python.org/': None}\n\nmyst_enable_extensions = [\n    'amsmath',\n    'dollarmath',\n    'colon_fence',\n]\n"
  },
  {
    "path": "docs/zh/index.rst",
    "content": "欢迎来到 DiffSynth-Studio 的文档\n================================\n\n.. toctree::\n   :maxdepth: 2\n   :caption: 文档介绍\n\n   README\n\n.. toctree::\n   :maxdepth: 2\n   :caption: 上手使用\n\n   Pipeline_Usage/Setup\n   Pipeline_Usage/Model_Inference\n   Pipeline_Usage/VRAM_management\n   Pipeline_Usage/Model_Training\n   Pipeline_Usage/Environment_Variables\n   Pipeline_Usage/GPU_support\n\n.. toctree::\n   :maxdepth: 2\n   :caption: 模型详解\n\n   Model_Details/FLUX\n   Model_Details/Wan\n   Model_Details/Qwen-Image\n   Model_Details/FLUX2\n   Model_Details/Z-Image\n   Model_Details/Anima\n   Model_Details/LTX-2\n\n.. toctree::\n   :maxdepth: 2\n   :caption: 训练框架\n\n   Training/Understanding_Diffusion_models\n   Training/Supervised_Fine_Tuning\n   Training/FP8_Precision\n   Training/Direct_Distill\n   Training/Split_Training\n   Training/Differential_LoRA\n\n.. toctree::\n   :maxdepth: 2\n   :caption: 模型接入\n\n   Developer_Guide/Integrating_Your_Model\n   Developer_Guide/Building_a_Pipeline\n   Developer_Guide/Enabling_VRAM_management\n   Developer_Guide/Training_Diffusion_Models\n\n.. toctree::\n   :maxdepth: 2\n   :caption: API 参考\n\n   API_Reference/core/attention\n   API_Reference/core/data\n   API_Reference/core/gradient\n   API_Reference/core/loader\n   API_Reference/core/vram\n\n.. toctree::\n   :maxdepth: 2\n   :caption: 学术导引\n\n   Research_Tutorial/train_from_scratch\n   Research_Tutorial/inference_time_scaling\n\n.. toctree::\n   :maxdepth: 2\n   :caption: 常见问题\n\n   QA\n\nIndices and tables\n==================\n* :ref:`genindex`\n* :ref:`modindex`\n* :ref:`search`\n"
  },
  {
    "path": "examples/anima/README.md",
    "content": "English Document: https://diffsynth-studio-doc.readthedocs.io/en/latest/Model_Details/Anima.html\n\n中文文档：https://diffsynth-studio-doc.readthedocs.io/zh-cn/latest/Model_Details/Anima.html\n"
  },
  {
    "path": "examples/anima/model_inference/anima-preview.py",
    "content": "from diffsynth.pipelines.anima_image import AnimaImagePipeline, ModelConfig\nimport torch\n\n\npipe = AnimaImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"circlestone-labs/Anima\", origin_file_pattern=\"split_files/diffusion_models/anima-preview.safetensors\"),\n        ModelConfig(model_id=\"circlestone-labs/Anima\", origin_file_pattern=\"split_files/text_encoders/qwen_3_06b_base.safetensors\"),\n        ModelConfig(model_id=\"circlestone-labs/Anima\", origin_file_pattern=\"split_files/vae/qwen_image_vae.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen3-0.6B\", origin_file_pattern=\"./\"),\n    tokenizer_t5xxl_config=ModelConfig(model_id=\"stabilityai/stable-diffusion-3.5-large\", origin_file_pattern=\"tokenizer_3/\")\n)\nprompt = \"Masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait.\"\nnegative_prompt = \"worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,\"\nimage = pipe(prompt, negative_prompt=negative_prompt, seed=0, num_inference_steps=50)\nimage.save(\"image.jpg\")\n"
  },
  {
    "path": "examples/anima/model_inference_low_vram/anima-preview.py",
    "content": "from diffsynth.pipelines.anima_image import AnimaImagePipeline, ModelConfig\nimport torch\n\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": \"disk\",\n    \"onload_device\": \"disk\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = AnimaImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"circlestone-labs/Anima\", origin_file_pattern=\"split_files/diffusion_models/anima-preview.safetensors\", **vram_config),\n        ModelConfig(model_id=\"circlestone-labs/Anima\", origin_file_pattern=\"split_files/text_encoders/qwen_3_06b_base.safetensors\", **vram_config),\n        ModelConfig(model_id=\"circlestone-labs/Anima\", origin_file_pattern=\"split_files/vae/qwen_image_vae.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen3-0.6B\", origin_file_pattern=\"./\"),\n    tokenizer_t5xxl_config=ModelConfig(model_id=\"stabilityai/stable-diffusion-3.5-large\", origin_file_pattern=\"tokenizer_3/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\nprompt = \"Masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait.\"\nnegative_prompt = \"worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,\"\nimage = pipe(prompt, negative_prompt=negative_prompt, seed=0, num_inference_steps=50)\nimage.save(\"image.jpg\")\n"
  },
  {
    "path": "examples/anima/model_training/full/anima-preview.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"anima/anima-preview/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/anima/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/anima/anima-preview \\\n  --dataset_metadata_path data/diffsynth_example_dataset/anima/anima-preview/metadata.csv \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 50 \\\n  --model_id_with_origin_paths \"circlestone-labs/Anima:split_files/diffusion_models/anima-preview.safetensors,circlestone-labs/Anima:split_files/text_encoders/qwen_3_06b_base.safetensors,circlestone-labs/Anima:split_files/vae/qwen_image_vae.safetensors\" \\\n  --tokenizer_path \"Qwen/Qwen3-0.6B:./\" \\\n  --tokenizer_t5xxl_path \"stabilityai/stable-diffusion-3.5-large:tokenizer_3/\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 2 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/anima-preview_full\" \\\n  --trainable_models \"dit\" \\\n  --use_gradient_checkpointing\n"
  },
  {
    "path": "examples/anima/model_training/lora/anima-preview.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"anima/anima-preview/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/anima/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/anima/anima-preview \\\n  --dataset_metadata_path data/diffsynth_example_dataset/anima/anima-preview/metadata.csv \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 50 \\\n  --model_id_with_origin_paths \"circlestone-labs/Anima:split_files/diffusion_models/anima-preview.safetensors,circlestone-labs/Anima:split_files/text_encoders/qwen_3_06b_base.safetensors,circlestone-labs/Anima:split_files/vae/qwen_image_vae.safetensors\" \\\n  --tokenizer_path \"Qwen/Qwen3-0.6B:./\" \\\n  --tokenizer_t5xxl_path \"stabilityai/stable-diffusion-3.5-large:tokenizer_3/\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/anima-preview_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"\" \\\n  --lora_rank 32 \\\n  --use_gradient_checkpointing\n"
  },
  {
    "path": "examples/anima/model_training/train.py",
    "content": "import torch, os, argparse, accelerate\nfrom diffsynth.core import UnifiedDataset\nfrom diffsynth.pipelines.anima_image import AnimaImagePipeline, ModelConfig\nfrom diffsynth.diffusion import *\nos.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n\n\nclass AnimaTrainingModule(DiffusionTrainingModule):\n    def __init__(\n        self,\n        model_paths=None, model_id_with_origin_paths=None,\n        tokenizer_path=None, tokenizer_t5xxl_path=None,\n        trainable_models=None,\n        lora_base_model=None, lora_target_modules=\"\", lora_rank=32, lora_checkpoint=None,\n        preset_lora_path=None, preset_lora_model=None,\n        use_gradient_checkpointing=True,\n        use_gradient_checkpointing_offload=False,\n        extra_inputs=None,\n        fp8_models=None,\n        offload_models=None,\n        device=\"cpu\",\n        task=\"sft\",\n    ):\n        super().__init__()\n        # Load models\n        model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device)\n        tokenizer_config = self.parse_path_or_model_id(tokenizer_path, ModelConfig(model_id=\"Qwen/Qwen3-0.6B\", origin_file_pattern=\"./\"))\n        tokenizer_t5xxl_config = self.parse_path_or_model_id(tokenizer_t5xxl_path, ModelConfig(model_id=\"stabilityai/stable-diffusion-3.5-large\", origin_file_pattern=\"tokenizer_3/\"))\n        self.pipe = AnimaImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config, tokenizer_t5xxl_config=tokenizer_t5xxl_config)\n        self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model)\n\n        # Training mode\n        self.switch_pipe_to_training_mode(\n            self.pipe, trainable_models,\n            lora_base_model, lora_target_modules, lora_rank, lora_checkpoint,\n            preset_lora_path, preset_lora_model,\n            
task=task,\n        )\n        \n        # Other configs\n        self.use_gradient_checkpointing = use_gradient_checkpointing\n        self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload\n        self.extra_inputs = extra_inputs.split(\",\") if extra_inputs is not None else []\n        self.fp8_models = fp8_models\n        self.task = task\n        self.task_to_loss = {\n            \"sft:data_process\": lambda pipe, *args: args,\n            \"direct_distill:data_process\": lambda pipe, *args: args,\n            \"sft\": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),\n            \"sft:train\": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),\n            \"direct_distill\": lambda pipe, inputs_shared, inputs_posi, inputs_nega: DirectDistillLoss(pipe, **inputs_shared, **inputs_posi),\n            \"direct_distill:train\": lambda pipe, inputs_shared, inputs_posi, inputs_nega: DirectDistillLoss(pipe, **inputs_shared, **inputs_posi),\n        }\n        \n    def get_pipeline_inputs(self, data):\n        inputs_posi = {\"prompt\": data[\"prompt\"]}\n        inputs_nega = {\"negative_prompt\": \"\"}\n        inputs_shared = {\n            # Assume you are using this pipeline for inference,\n            # please fill in the input parameters.\n            \"input_image\": data[\"image\"],\n            \"height\": data[\"image\"].size[1],\n            \"width\": data[\"image\"].size[0],\n            # Please do not modify the following parameters\n            # unless you clearly know what this will cause.\n            \"cfg_scale\": 1,\n            \"rand_device\": self.pipe.device,\n            \"use_gradient_checkpointing\": self.use_gradient_checkpointing,\n            \"use_gradient_checkpointing_offload\": self.use_gradient_checkpointing_offload,\n        }\n        inputs_shared = self.parse_extra_inputs(data, 
self.extra_inputs, inputs_shared)\n        return inputs_shared, inputs_posi, inputs_nega\n    \n    def forward(self, data, inputs=None):\n        if inputs is None: inputs = self.get_pipeline_inputs(data)\n        inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype)\n        for unit in self.pipe.units:\n            inputs = self.pipe.unit_runner(unit, self.pipe, *inputs)\n        loss = self.task_to_loss[self.task](self.pipe, *inputs)\n        return loss\n\n\ndef anima_parser():\n    parser = argparse.ArgumentParser(description=\"Training script for Anima models.\")\n    parser = add_general_config(parser)\n    parser = add_image_size_config(parser)\n    parser.add_argument(\"--tokenizer_path\", type=str, default=None, help=\"Path to tokenizer.\")\n    parser.add_argument(\"--tokenizer_t5xxl_path\", type=str, default=None, help=\"Path to tokenizer_t5xxl.\")\n    return parser\n\n\nif __name__ == \"__main__\":\n    parser = anima_parser()\n    args = parser.parse_args()\n    accelerator = accelerate.Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        kwargs_handlers=[accelerate.DistributedDataParallelKwargs(find_unused_parameters=args.find_unused_parameters)],\n    )\n    dataset = UnifiedDataset(\n        base_path=args.dataset_base_path,\n        metadata_path=args.dataset_metadata_path,\n        repeat=args.dataset_repeat,\n        data_file_keys=args.data_file_keys.split(\",\"),\n        main_data_operator=UnifiedDataset.default_image_operator(\n            base_path=args.dataset_base_path,\n            max_pixels=args.max_pixels,\n            height=args.height,\n            width=args.width,\n            height_division_factor=16,\n            width_division_factor=16,\n        )\n    )\n    model = AnimaTrainingModule(\n        model_paths=args.model_paths,\n        model_id_with_origin_paths=args.model_id_with_origin_paths,\n        tokenizer_path=args.tokenizer_path,\n        
tokenizer_t5xxl_path=args.tokenizer_t5xxl_path,\n        trainable_models=args.trainable_models,\n        lora_base_model=args.lora_base_model,\n        lora_target_modules=args.lora_target_modules,\n        lora_rank=args.lora_rank,\n        lora_checkpoint=args.lora_checkpoint,\n        preset_lora_path=args.preset_lora_path,\n        preset_lora_model=args.preset_lora_model,\n        use_gradient_checkpointing=args.use_gradient_checkpointing,\n        use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,\n        extra_inputs=args.extra_inputs,\n        fp8_models=args.fp8_models,\n        offload_models=args.offload_models,\n        task=args.task,\n        device=accelerator.device,\n    )\n    model_logger = ModelLogger(\n        args.output_path,\n        remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,\n    )\n    launcher_map = {\n        \"sft:data_process\": launch_data_process_task,\n        \"direct_distill:data_process\": launch_data_process_task,\n        \"sft\": launch_training_task,\n        \"sft:train\": launch_training_task,\n        \"direct_distill\": launch_training_task,\n        \"direct_distill:train\": launch_training_task,\n    }\n    launcher_map[args.task](accelerator, dataset, model, model_logger, args=args)"
  },
  {
    "path": "examples/anima/model_training/validate_full/anima-preview.py",
    "content": "from diffsynth.pipelines.anima_image import AnimaImagePipeline, ModelConfig\nfrom diffsynth.core import load_state_dict\nimport torch\n\n\npipe = AnimaImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"circlestone-labs/Anima\", origin_file_pattern=\"split_files/diffusion_models/anima-preview.safetensors\"),\n        ModelConfig(model_id=\"circlestone-labs/Anima\", origin_file_pattern=\"split_files/text_encoders/qwen_3_06b_base.safetensors\"),\n        ModelConfig(model_id=\"circlestone-labs/Anima\", origin_file_pattern=\"split_files/vae/qwen_image_vae.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen3-0.6B\", origin_file_pattern=\"./\"),\n    tokenizer_t5xxl_config=ModelConfig(model_id=\"stabilityai/stable-diffusion-3.5-large\", origin_file_pattern=\"tokenizer_3/\")\n)\nstate_dict = load_state_dict(\"./models/train/anima-preview_full/epoch-1.safetensors\", torch_dtype=torch.bfloat16)\npipe.dit.load_state_dict(state_dict)\nprompt = \"a dog\"\nimage = pipe(prompt=prompt, seed=0)\nimage.save(\"image.jpg\")"
  },
  {
    "path": "examples/anima/model_training/validate_lora/anima-preview.py",
    "content": "from diffsynth.pipelines.anima_image import AnimaImagePipeline, ModelConfig\nimport torch\n\n\npipe = AnimaImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"circlestone-labs/Anima\", origin_file_pattern=\"split_files/diffusion_models/anima-preview.safetensors\"),\n        ModelConfig(model_id=\"circlestone-labs/Anima\", origin_file_pattern=\"split_files/text_encoders/qwen_3_06b_base.safetensors\"),\n        ModelConfig(model_id=\"circlestone-labs/Anima\", origin_file_pattern=\"split_files/vae/qwen_image_vae.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen3-0.6B\", origin_file_pattern=\"./\"),\n    tokenizer_t5xxl_config=ModelConfig(model_id=\"stabilityai/stable-diffusion-3.5-large\", origin_file_pattern=\"tokenizer_3/\")\n)\npipe.load_lora(pipe.dit, \"./models/train/anima-preview_lora/epoch-4.safetensors\")\nprompt = \"a dog\"\nimage = pipe(prompt=prompt, seed=0)\nimage.save(\"image.jpg\")"
  },
  {
    "path": "examples/dev_tools/fix_path.py",
    "content": "import re, os\n\n\ndef read_file(path):\n    with open(path, \"r\", encoding=\"utf-8-sig\") as f:\n        context = f.read()\n    return context\n\ndef get_files(files, path):\n    if os.path.isdir(path):\n        for folder in os.listdir(path):\n            get_files(files, os.path.join(path, folder))\n    elif path.endswith(\".md\"):\n        files.append(path)\n        \ndef fix_path(doc_root_path):\n    files = []\n    get_files(files, doc_root_path)\n    file_map = {}\n    for file in files:\n        name = file.split(\"/\")[-1]\n        file_map[name] = \"/\" + file\n\n    pattern = re.compile(r'\\]\\([^)]*\\.md')\n    for file in files:\n        context = read_file(file)\n        matches = pattern.findall(context)\n        \n        edited = False\n        for match in matches:\n            target = \"](\" + file_map[match.split(\"/\")[-1].replace(\"](\", \"\")]\n            context = context.replace(match, target)\n            if target != match:\n                print(match, target)\n                edited = True\n            print(file, match, target)\n        \n        if edited:\n            with open(file, \"w\", encoding=\"utf-8\") as f:\n                f.write(context)\n\nfix_path(\"doc/zh\")\nfix_path(\"doc/en\")"
  },
  {
    "path": "examples/dev_tools/unit_test.py",
    "content": "import os, shutil, multiprocessing, time\nNUM_GPUS = 7\n\n\ndef script_is_processed(output_path, script):\n    return os.path.exists(os.path.join(output_path, script)) and \"log.txt\" in os.listdir(os.path.join(output_path, script))\n\n\ndef filter_unprocessed_tasks(script_path):\n    tasks = []\n    output_path = os.path.join(\"data\", script_path)\n    for script in sorted(os.listdir(script_path)):\n        if not script.endswith(\".sh\") and not script.endswith(\".py\"):\n            continue\n        if script_is_processed(output_path, script):\n            continue\n        tasks.append(script)\n    return tasks\n\n\ndef run_inference(script_path):\n    tasks = filter_unprocessed_tasks(script_path)\n    output_path = os.path.join(\"data\", script_path)\n    for script in tasks:\n        source_path = os.path.join(script_path, script)\n        target_path = os.path.join(output_path, script)\n        os.makedirs(target_path, exist_ok=True)\n        cmd = f\"python {source_path} > {target_path}/log.txt 2>&1\"\n        print(cmd, flush=True)\n        os.system(cmd)\n        for file_name in os.listdir(\"./\"):\n            if file_name.endswith(\".jpg\") or file_name.endswith(\".png\") or file_name.endswith(\".mp4\"):\n                shutil.move(file_name, os.path.join(target_path, file_name))\n\n\ndef run_tasks_on_single_GPU(script_path, tasks, gpu_id, num_gpu):\n    output_path = os.path.join(\"data\", script_path)\n    for script_id, script in enumerate(tasks):\n        if script_id % num_gpu != gpu_id:\n            continue\n        source_path = os.path.join(script_path, script)\n        target_path = os.path.join(output_path, script)\n        os.makedirs(target_path, exist_ok=True)\n        if script.endswith(\".sh\"):\n            cmd = f\"CUDA_VISIBLE_DEVICES={gpu_id} bash {source_path} > {target_path}/log.txt 2>&1\"\n        elif script.endswith(\".py\"):\n            cmd = f\"CUDA_VISIBLE_DEVICES={gpu_id} python {source_path} > 
{target_path}/log.txt 2>&1\"\n        print(cmd, flush=True)\n        os.system(cmd)\n\n\ndef run_train_multi_GPU(script_path):\n    tasks = filter_unprocessed_tasks(script_path)\n    output_path = os.path.join(\"data\", script_path)\n    for script in tasks:\n        source_path = os.path.join(script_path, script)\n        target_path = os.path.join(output_path, script)\n        os.makedirs(target_path, exist_ok=True)\n        cmd = f\"bash {source_path} > {target_path}/log.txt 2>&1\"\n        print(cmd, flush=True)\n        os.system(cmd)\n        time.sleep(1)\n        \n\ndef run_train_single_GPU(script_path):\n    tasks = filter_unprocessed_tasks(script_path)\n    processes = [multiprocessing.Process(target=run_tasks_on_single_GPU, args=(script_path, tasks, i, NUM_GPUS)) for i in range(NUM_GPUS)]\n    for p in processes:\n        p.start()\n    for p in processes:\n        p.join()\n\n\ndef move_files(prefix, target_folder):\n    os.makedirs(target_folder, exist_ok=True)\n    os.system(f\"cp -r {prefix}* {target_folder}\")\n    os.system(f\"rm -rf {prefix}*\")\n\n\ndef test_qwen_image():\n    run_inference(\"examples/qwen_image/model_inference\")\n    run_inference(\"examples/qwen_image/model_inference_low_vram\")\n    run_train_multi_GPU(\"examples/qwen_image/model_training/full\")\n    run_inference(\"examples/qwen_image/model_training/validate_full\")\n    run_train_single_GPU(\"examples/qwen_image/model_training/lora\")\n    run_inference(\"examples/qwen_image/model_training/validate_lora\")\n    \n\ndef test_wan():\n    run_train_single_GPU(\"examples/wanvideo/model_inference\")\n    move_files(\"video_\", \"data/output/model_inference\")\n    run_train_single_GPU(\"examples/wanvideo/model_inference_low_vram\")\n    move_files(\"video_\", \"data/output/model_inference_low_vram\")\n    run_train_multi_GPU(\"examples/wanvideo/model_training/full\")\n    run_train_single_GPU(\"examples/wanvideo/model_training/validate_full\")\n    move_files(\"video_\", 
\"data/output/validate_full\")\n    run_train_single_GPU(\"examples/wanvideo/model_training/lora\")\n    run_train_single_GPU(\"examples/wanvideo/model_training/validate_lora\")\n    move_files(\"video_\", \"data/output/validate_lora\")\n\n\ndef test_flux():\n    run_inference(\"examples/flux/model_inference\")\n    run_inference(\"examples/flux/model_inference_low_vram\")\n    run_train_multi_GPU(\"examples/flux/model_training/full\")\n    run_inference(\"examples/flux/model_training/validate_full\")\n    run_train_single_GPU(\"examples/flux/model_training/lora\")\n    run_inference(\"examples/flux/model_training/validate_lora\")\n\n\ndef test_z_image():\n    run_inference(\"examples/z_image/model_inference\")\n    run_inference(\"examples/z_image/model_inference_low_vram\")\n    run_train_multi_GPU(\"examples/z_image/model_training/full\")\n    run_inference(\"examples/z_image/model_training/validate_full\")\n    run_train_single_GPU(\"examples/z_image/model_training/lora\")\n    run_inference(\"examples/z_image/model_training/validate_lora\")\n\n\nif __name__ == \"__main__\":\n    test_z_image()\n"
  },
  {
    "path": "examples/flux/README.md",
    "content": "English Document: https://diffsynth-studio-doc.readthedocs.io/en/latest/Model_Details/FLUX.html\n\n中文文档：https://diffsynth-studio-doc.readthedocs.io/zh-cn/latest/Model_Details/FLUX.html\n"
  },
  {
    "path": "examples/flux/model_inference/FLEX.2-preview.py",
    "content": "import torch\nfrom diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig\nfrom diffsynth.utils.controlnet import Annotator\nimport numpy as np\nfrom PIL import Image\n\n\npipe = FluxImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"ostris/Flex.2-preview\", origin_file_pattern=\"Flex.2-preview.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder/model.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder_2/*.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"ae.safetensors\"),\n    ],\n)\n\nimage = pipe(\n    prompt=\"portrait of a beautiful Asian girl, long hair, red t-shirt, sunshine, beach\",\n    num_inference_steps=50, embedded_guidance=3.5,\n    seed=0\n)\nimage.save(\"image_1.jpg\")\n\nmask = np.zeros((1024, 1024, 3), dtype=np.uint8)\nmask[200:400, 400:700] = 255\nmask = Image.fromarray(mask)\nmask.save(\"image_mask.jpg\")\n\ninpaint_image = image\n\nimage = pipe(\n    prompt=\"portrait of a beautiful Asian girl with sunglasses, long hair, red t-shirt, sunshine, beach\",\n    num_inference_steps=50, embedded_guidance=3.5,\n    flex_inpaint_image=inpaint_image, flex_inpaint_mask=mask,\n    seed=4\n)\nimage.save(\"image_2.jpg\")\n\ncontrol_image = Annotator(\"canny\")(image)\ncontrol_image.save(\"image_control.jpg\")\n\nimage = pipe(\n    prompt=\"portrait of a beautiful Asian girl with sunglasses, long hair, yellow t-shirt, sunshine, beach\",\n    num_inference_steps=50, embedded_guidance=3.5,\n    flex_control_image=control_image,\n    seed=4\n)\nimage.save(\"image_3.jpg\")\n"
  },
  {
    "path": "examples/flux/model_inference/FLUX.1-Kontext-dev.py",
    "content": "import torch\nfrom diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig\nfrom PIL import Image\n\n\npipe = FluxImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-Kontext-dev\", origin_file_pattern=\"flux1-kontext-dev.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder/model.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder_2/*.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"ae.safetensors\"),\n    ],\n)\n\nimage_1 = pipe(\n    prompt=\"a beautiful Asian long-haired female college student.\",\n    embedded_guidance=2.5,\n    seed=1,\n)\nimage_1.save(\"image_1.jpg\")\n\nimage_2 = pipe(\n    prompt=\"transform the style to anime style.\",\n    kontext_images=image_1,\n    embedded_guidance=2.5,\n    seed=2,\n)\nimage_2.save(\"image_2.jpg\")\n\nimage_3 = pipe(\n    prompt=\"let her smile.\",\n    kontext_images=image_1,\n    embedded_guidance=2.5,\n    seed=3,\n)\nimage_3.save(\"image_3.jpg\")\n\nimage_4 = pipe(\n    prompt=\"let the girl play basketball.\",\n    kontext_images=image_1,\n    embedded_guidance=2.5,\n    seed=4,\n)\nimage_4.save(\"image_4.jpg\")\n\nimage_5 = pipe(\n    prompt=\"move the girl to a park, let her sit on a chair.\",\n    kontext_images=image_1,\n    embedded_guidance=2.5,\n    seed=5,\n)\nimage_5.save(\"image_5.jpg\")"
  },
  {
    "path": "examples/flux/model_inference/FLUX.1-Krea-dev.py",
    "content": "import torch\nfrom diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig\n\n\npipe = FluxImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-Krea-dev\", origin_file_pattern=\"flux1-krea-dev.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder/model.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder_2/*.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"ae.safetensors\"),\n    ],\n)\n\nprompt = \"An beautiful woman is riding a bicycle in a park, wearing a red dress\"\nnegative_prompt = \"worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,\"\n\nimage = pipe(prompt=prompt, seed=0, embedded_guidance=4.5)\nimage.save(\"flux_krea.jpg\")\n\nimage = pipe(\n    prompt=prompt, negative_prompt=negative_prompt,\n    seed=0, cfg_scale=2, num_inference_steps=50,\n    embedded_guidance=4.5\n)\nimage.save(\"flux_krea_cfg.jpg\")\n"
  },
  {
    "path": "examples/flux/model_inference/FLUX.1-dev-AttriCtrl.py",
    "content": "import torch\nfrom diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig\n\n\npipe = FluxImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"flux1-dev.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder/model.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder_2/*.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"ae.safetensors\"),\n        ModelConfig(model_id=\"DiffSynth-Studio/AttriCtrl-FLUX.1-Dev\", origin_file_pattern=\"models/brightness.safetensors\")\n    ],\n)\n\nfor i in [0.1, 0.3, 0.5, 0.7, 0.9]:\n    image = pipe(prompt=\"a cat on the beach\", seed=2, value_controller_inputs=[i])\n    image.save(f\"value_control_{i}.jpg\")\n"
  },
  {
    "path": "examples/flux/model_inference/FLUX.1-dev-Controlnet-Inpainting-Beta.py",
    "content": "import torch\nfrom diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig, ControlNetInput\nimport numpy as np\nfrom PIL import Image\n\n\npipe = FluxImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"flux1-dev.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder/model.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder_2/*.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"ae.safetensors\"),\n        ModelConfig(model_id=\"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta\", origin_file_pattern=\"diffusion_pytorch_model.safetensors\"),\n    ],\n)\n\nimage_1 = pipe(\n    prompt=\"a cat sitting on a chair\",\n    height=1024, width=1024,\n    seed=8, rand_device=\"cuda\",\n)\nimage_1.save(\"image_1.jpg\")\n\nmask = np.zeros((1024, 1024, 3), dtype=np.uint8)\nmask[100:350, 350: -300] = 255\nmask = Image.fromarray(mask)\nmask.save(\"mask.jpg\")\n\nimage_2 = pipe(\n    prompt=\"a cat sitting on a chair, wearing sunglasses\",\n    controlnet_inputs=[ControlNetInput(image=image_1, inpaint_mask=mask, scale=0.9)],\n    height=1024, width=1024,\n    seed=9, rand_device=\"cuda\",\n)\nimage_2.save(\"image_2.jpg\")"
  },
  {
    "path": "examples/flux/model_inference/FLUX.1-dev-Controlnet-Union-alpha.py",
    "content": "import torch\nfrom diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig, ControlNetInput\nfrom diffsynth.utils.controlnet import Annotator\nfrom modelscope import snapshot_download\n\n\n\nsnapshot_download(\"sd_lora/Annotators\", allow_file_pattern=\"dpt_hybrid-midas-501f0c75.pt\", local_dir=\"models/Annotators\")\npipe = FluxImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"flux1-dev.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder/model.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder_2/*.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"ae.safetensors\"),\n        ModelConfig(model_id=\"InstantX/FLUX.1-dev-Controlnet-Union-alpha\", origin_file_pattern=\"diffusion_pytorch_model.safetensors\"),\n    ],\n)\n\nimage_1 = pipe(\n    prompt=\"a beautiful Asian girl, full body, red dress, summer\",\n    height=1024, width=1024,\n    seed=6, rand_device=\"cuda\",\n)\nimage_1.save(\"image_1.jpg\")\n\nimage_canny = Annotator(\"canny\")(image_1)\nimage_depth = Annotator(\"depth\")(image_1)\n\nimage_2 = pipe(\n    prompt=\"a beautiful Asian girl, full body, red dress, winter\",\n    controlnet_inputs=[\n        ControlNetInput(image=image_canny, scale=0.3, processor_id=\"canny\"),\n        ControlNetInput(image=image_depth, scale=0.3, processor_id=\"depth\"),\n    ],\n    height=1024, width=1024,\n    seed=7, rand_device=\"cuda\",\n)\nimage_2.save(\"image_2.jpg\")\n"
  },
  {
    "path": "examples/flux/model_inference/FLUX.1-dev-Controlnet-Upscaler.py",
    "content": "import torch\nfrom diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig, ControlNetInput\n\n\npipe = FluxImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"flux1-dev.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder/model.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder_2/*.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"ae.safetensors\"),\n        ModelConfig(model_id=\"jasperai/Flux.1-dev-Controlnet-Upscaler\", origin_file_pattern=\"diffusion_pytorch_model.safetensors\"),\n    ],\n)\n\nimage_1 = pipe(\n    prompt=\"a photo of a cat, highly detailed\",\n    height=768, width=768,\n    seed=0, rand_device=\"cuda\",\n)\nimage_1.save(\"image_1.jpg\")\n\nimage_1 = image_1.resize((2048, 2048))\nimage_2 = pipe(\n    prompt=\"a photo of a cat, highly detailed\",\n    controlnet_inputs=[ControlNetInput(image=image_1, scale=0.7)],\n    input_image=image_1,\n    denoising_strength=0.99,\n    height=2048, width=2048, tiled=True,\n    seed=1, rand_device=\"cuda\",\n)\nimage_2.save(\"image_2.jpg\")"
  },
  {
    "path": "examples/flux/model_inference/FLUX.1-dev-EliGen.py",
    "content": "import random\nimport torch\nfrom PIL import Image, ImageDraw, ImageFont\nfrom diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\ndef visualize_masks(image, masks, mask_prompts, output_path, font_size=35, use_random_colors=False):\n    # Create a blank image for overlays\n    overlay = Image.new('RGBA', image.size, (0, 0, 0, 0))\n    \n    colors = [\n        (165, 238, 173, 80),\n        (76, 102, 221, 80),\n        (221, 160, 77, 80),\n        (204, 93, 71, 80),\n        (145, 187, 149, 80),\n        (134, 141, 172, 80),\n        (157, 137, 109, 80),\n        (153, 104, 95, 80),\n        (165, 238, 173, 80),\n        (76, 102, 221, 80),\n        (221, 160, 77, 80),\n        (204, 93, 71, 80),\n        (145, 187, 149, 80),\n        (134, 141, 172, 80),\n        (157, 137, 109, 80),\n        (153, 104, 95, 80),\n    ]\n    # Generate random colors for each mask\n    if use_random_colors:\n        colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 80) for _ in range(len(masks))]\n    \n    # Font settings\n    try:\n        font = ImageFont.truetype(\"arial\", font_size)  # Adjust as needed\n    except IOError:\n        font = ImageFont.load_default(font_size)\n\n    # Overlay each mask onto the overlay image\n    for mask, mask_prompt, color in zip(masks, mask_prompts, colors):\n        # Convert mask to RGBA mode\n        mask_rgba = mask.convert('RGBA')\n        mask_data = mask_rgba.getdata()\n        new_data = [(color if item[:3] == (255, 255, 255) else (0, 0, 0, 0)) for item in mask_data]\n        mask_rgba.putdata(new_data)\n\n        # Draw the mask prompt text on the mask\n        draw = ImageDraw.Draw(mask_rgba)\n        mask_bbox = mask.getbbox()  # Get the bounding box of the mask\n        text_position = (mask_bbox[0] + 10, mask_bbox[1] + 10)  # Adjust text position based on mask position\n        draw.text(text_position, 
mask_prompt, fill=(255, 255, 255, 255), font=font)\n\n        # Alpha composite the overlay with this mask\n        overlay = Image.alpha_composite(overlay, mask_rgba)\n    \n    # Composite the overlay onto the original image\n    result = Image.alpha_composite(image.convert('RGBA'), overlay)\n    \n    # Save or display the resulting image\n    result.save(output_path)\n\n    return result\n\ndef example(pipe, seeds, example_id, global_prompt, entity_prompts):\n    dataset_snapshot_download(dataset_id=\"DiffSynth-Studio/examples_in_diffsynth\", local_dir=\"./\", allow_file_pattern=f\"data/examples/eligen/entity_control/example_{example_id}/*.png\")\n    masks = [Image.open(f\"./data/examples/eligen/entity_control/example_{example_id}/{i}.png\").convert('RGB') for i in range(len(entity_prompts))]\n    negative_prompt = \"worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,\"\n    for seed in seeds:\n        # generate image\n        image = pipe(\n            prompt=global_prompt,\n            cfg_scale=3.0,\n            negative_prompt=negative_prompt,\n            num_inference_steps=50,\n            embedded_guidance=3.5,\n            seed=seed,\n            height=1024,\n            width=1024,\n            eligen_entity_prompts=entity_prompts,\n            eligen_entity_masks=masks,\n        )\n        image.save(f\"eligen_example_{example_id}_{seed}.png\")\n        visualize_masks(image, masks, entity_prompts, f\"eligen_example_{example_id}_mask_{seed}.png\")\n\n\npipe = FluxImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"flux1-dev.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder/model.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", 
origin_file_pattern=\"text_encoder_2/*.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"ae.safetensors\"),\n    ],\n)\npipe.load_lora(pipe.dit, ModelConfig(model_id=\"DiffSynth-Studio/Eligen\", origin_file_pattern=\"model_bf16.safetensors\"), alpha=1)\n\n# example 1\nglobal_prompt = \"A breathtaking beauty of Raja Ampat by the late-night moonlight , one beautiful woman from behind wearing a pale blue long dress with soft glow, sitting at the top of a cliff looking towards the beach,pastell light colors, a group of small distant birds flying in far sky, a boat sailing on the sea, best quality, realistic, whimsical, fantastic, splash art, intricate detailed, hyperdetailed, maximalist style, photorealistic, concept art, sharp focus, harmony, serenity, tranquility, soft pastell colors,ambient occlusion, cozy ambient lighting, masterpiece, liiv1, linquivera, metix, mentixis, masterpiece, award winning, view from above\\n\"\nentity_prompts = [\"cliff\", \"sea\", \"moon\", \"sailing boat\", \"a seated beautiful woman\", \"pale blue long dress with soft glow\"]\nexample(pipe, [0], 1, global_prompt, entity_prompts)\n\n# example 2\nglobal_prompt = \"samurai girl wearing a kimono, she's holding a sword  glowing with red flame, her long hair is flowing in the wind, she is looking at a small bird perched on the back of her hand. ultra realist style. maximum image detail. 
maximum realistic render.\"\nentity_prompts = [\"flowing hair\", \"sword glowing with red flame\", \"A cute bird\", \"blue belt\"]\nexample(pipe, [0], 2, global_prompt, entity_prompts)\n\n# example 3\nglobal_prompt = \"Image of a neverending staircase up to a mysterious palace in the sky, The ancient palace stood majestically atop a mist-shrouded mountain, sunrise, two traditional monk walk in the stair looking at the sunrise, fog,see-through, best quality, whimsical, fantastic, splash art, intricate detailed, hyperdetailed, photorealistic, concept art, harmony, serenity, tranquility, ambient occlusion, halation, cozy ambient lighting, dynamic lighting,masterpiece, liiv1, linquivera, metix, mentixis, masterpiece, award winning,\"\nentity_prompts = [\"ancient palace\", \"stone staircase with railings\", \"a traditional monk\", \"a traditional monk\"]\nexample(pipe, [27], 3, global_prompt, entity_prompts)\n\n# example 4\nglobal_prompt = \"A beautiful girl wearing shirt and shorts in the street,  holding a sign 'Entity Control'\"\nentity_prompts = [\"A beautiful girl\", \"sign 'Entity Control'\", \"shorts\", \"shirt\"]\nexample(pipe, [21], 4, global_prompt, entity_prompts)\n\n# example 5\nglobal_prompt = \"A captivating, dramatic scene in a painting that exudes mystery and foreboding. A white sky, swirling blue clouds, and a crescent yellow moon illuminate a solitary woman standing near the water's edge. Her long dress flows in the wind, silhouetted against the eerie glow. 
The water mirrors the fiery sky and moonlight, amplifying the uneasy atmosphere.\"\nentity_prompts = [\"crescent yellow moon\", \"a solitary woman\", \"water\", \"swirling blue clouds\"]\nexample(pipe, [0], 5, global_prompt, entity_prompts)\n\n# example 6\nglobal_prompt = \"Snow White and the 6 Dwarfs.\"\nentity_prompts = [\"Dwarf 1\", \"Dwarf 2\", \"Dwarf 3\", \"Snow White\", \"Dwarf 4\", \"Dwarf 5\", \"Dwarf 6\"]\nexample(pipe, [8], 6, global_prompt, entity_prompts)\n\n# example 7, same prompt with different seeds\nseeds = range(5, 9)\nglobal_prompt = \"A beautiful woman wearing white dress, holding a mirror, with a warm light background;\"\nentity_prompts = [\"A beautiful woman\", \"mirror\", \"necklace\", \"glasses\", \"earring\", \"white dress\", \"jewelry headpiece\"]\nexample(pipe, seeds, 7, global_prompt, entity_prompts)\n"
  },
  {
    "path": "examples/flux/model_inference/FLUX.1-dev-IP-Adapter.py",
    "content": "import torch\nfrom diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig\n\n\npipe = FluxImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"flux1-dev.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder/model.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder_2/*.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"ae.safetensors\"),\n        ModelConfig(model_id=\"InstantX/FLUX.1-dev-IP-Adapter\", origin_file_pattern=\"ip-adapter.bin\"),\n        ModelConfig(model_id=\"google/siglip-so400m-patch14-384\", origin_file_pattern=\"model.safetensors\"),\n    ],\n)\n\norigin_prompt = \"a rabbit in a garden, colorful flowers\"\nimage = pipe(prompt=origin_prompt, height=1280, width=960, seed=42)\nimage.save(\"style image.jpg\")\n\nimage = pipe(prompt=\"A piggy\", height=1280, width=960, seed=42,\n    ipadapter_images=[image], ipadapter_scale=0.7)\nimage.save(\"A piggy.jpg\")\n"
  },
  {
    "path": "examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py",
    "content": "import torch\nfrom diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig, ControlNetInput\nfrom modelscope import dataset_snapshot_download\nfrom modelscope import snapshot_download\nfrom PIL import Image\nimport numpy as np\n\n# This model has additional requirements.\n# Please install the following packages.\n# pip install facexlib insightface onnxruntime\nsnapshot_download(\n    \"ByteDance/InfiniteYou\",\n    allow_file_pattern=\"supports/insightface/models/antelopev2/*\",\n    local_dir=\"models/ByteDance/InfiniteYou\",\n)\npipe = FluxImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"flux1-dev.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder/model.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder_2/*.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"ae.safetensors\"),\n        ModelConfig(model_id=\"ByteDance/InfiniteYou\", origin_file_pattern=\"infu_flux_v1.0/aes_stage2/image_proj_model.bin\"),\n        ModelConfig(model_id=\"ByteDance/InfiniteYou\", origin_file_pattern=\"infu_flux_v1.0/aes_stage2/InfuseNetModel/*.safetensors\"),\n    ],\n)\n\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/examples_in_diffsynth\",\n    local_dir=\"./\",\n    allow_file_pattern=f\"data/examples/infiniteyou/*\",\n)\n\nheight, width = 1024, 1024\ncontrolnet_image = Image.fromarray(np.zeros([height, width, 3]).astype(np.uint8))\ncontrolnet_inputs = [ControlNetInput(image=controlnet_image, scale=1.0, processor_id=\"None\")]\n\nprompt = \"A man, portrait, cinematic\"\nid_image = \"data/examples/infiniteyou/man.jpg\"\nid_image = Image.open(id_image).convert('RGB')\nimage = pipe(\n    prompt=prompt, seed=1,\n    
infinityou_id_image=id_image, infinityou_guidance=1.0,\n    controlnet_inputs=controlnet_inputs,\n    num_inference_steps=50, embedded_guidance=3.5,\n    height=height, width=width,\n)\nimage.save(\"man.jpg\")\n\nprompt = \"A woman, portrait, cinematic\"\nid_image = \"data/examples/infiniteyou/woman.jpg\"\nid_image = Image.open(id_image).convert('RGB')\nimage = pipe(\n    prompt=prompt, seed=1,\n    infinityou_id_image=id_image, infinityou_guidance=1.0,\n    controlnet_inputs=controlnet_inputs,\n    num_inference_steps=50, embedded_guidance=3.5,\n    height=height, width=width,\n)\nimage.save(\"woman.jpg\")"
  },
  {
    "path": "examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py",
    "content": "import torch\nfrom diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig\n\n\npipe = FluxImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"flux1-dev.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder/model.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder_2/*.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"ae.safetensors\"),\n        ModelConfig(model_id=\"DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev\", origin_file_pattern=\"model.safetensors\"),\n    ],\n)\nlora = ModelConfig(model_id=\"VoidOc/flux_animal_forest1\", origin_file_pattern=\"20.safetensors\")\npipe.load_lora(pipe.dit, lora) # Use `pipe.clear_lora()` to drop the loaded LoRA.\n\n# Empty prompt can automatically activate LoRA capabilities.\nimage = pipe(prompt=\"\", seed=0, lora_encoder_inputs=lora)\nimage.save(\"image_1.jpg\")\n\nimage = pipe(prompt=\"\", seed=0)\nimage.save(\"image_1_origin.jpg\")\n\n# Prompt without trigger words can also activate LoRA capabilities.\nimage = pipe(prompt=\"a car\", seed=0, lora_encoder_inputs=lora)\nimage.save(\"image_2.jpg\")\n\nimage = pipe(prompt=\"a car\", seed=0,)\nimage.save(\"image_2_origin.jpg\")\n\n# Adjust the activation intensity through the scale parameter.\nimage = pipe(prompt=\"a cat\", seed=0, lora_encoder_inputs=lora, lora_encoder_scale=1.0)\nimage.save(\"image_3.jpg\")\n\nimage = pipe(prompt=\"a cat\", seed=0, lora_encoder_inputs=lora, lora_encoder_scale=0.5)\nimage.save(\"image_3_scale.jpg\")\n"
  },
  {
    "path": "examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py",
    "content": "import torch\nfrom diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig\n\n\nvram_config = {\n    # Enable lora hotloading\n    \"offload_dtype\": torch.bfloat16,\n    \"offload_device\": \"cuda\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cuda\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = FluxImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"flux1-dev.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder/model.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder_2/*.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"ae.safetensors\"),\n        ModelConfig(model_id=\"DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev\", origin_file_pattern=\"model.safetensors\"),\n    ],\n)\npipe.enable_lora_merger()\n\npipe.load_lora(\n    pipe.dit,\n    ModelConfig(model_id=\"cancel13/cxsk\", origin_file_pattern=\"30.safetensors\"),\n)\npipe.load_lora(\n    pipe.dit,\n    ModelConfig(model_id=\"DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1\", origin_file_pattern=\"merged_lora.safetensors\"),\n)\nimage = pipe(prompt=\"a cat\", seed=0)\nimage.save(\"image_fused.jpg\")\n"
  },
  {
    "path": "examples/flux/model_inference/FLUX.1-dev.py",
    "content": "import torch\nfrom diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig\n\n\npipe = FluxImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"flux1-dev.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder/model.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder_2/*.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"ae.safetensors\"),\n    ],\n)\n\nprompt = \"CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her.\"\nnegative_prompt = \"worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,\"\n\nimage = pipe(prompt=prompt, seed=0)\nimage.save(\"flux.jpg\")\n\nimage = pipe(\n    prompt=prompt, negative_prompt=negative_prompt,\n    seed=0, cfg_scale=2, num_inference_steps=50,\n)\nimage.save(\"flux_cfg.jpg\")\n"
  },
  {
    "path": "examples/flux/model_inference/Nexus-Gen-Editing.py",
    "content": "import importlib\nimport torch\nfrom PIL import Image\nfrom diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\nif importlib.util.find_spec(\"transformers\") is None:\n    raise ImportError(\"You are using Nexus-GenV2. It depends on transformers, which is not installed. Please install it with `pip install transformers==4.49.0`.\")\nelse:\n    import transformers\n    assert transformers.__version__ == \"4.49.0\", \"Nexus-GenV2 requires transformers==4.49.0, please install it with `pip install transformers==4.49.0`.\"\n\n\npipe = FluxImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"DiffSynth-Studio/Nexus-GenV2\", origin_file_pattern=\"model*.safetensors\"),\n        ModelConfig(model_id=\"DiffSynth-Studio/Nexus-GenV2\", origin_file_pattern=\"edit_decoder.bin\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder/model.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder_2/*.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"ae.safetensors\"),\n    ],\n    nexus_gen_processor_config=ModelConfig(model_id=\"DiffSynth-Studio/Nexus-GenV2\", origin_file_pattern=\"processor/\"),\n)\n\ndataset_snapshot_download(dataset_id=\"DiffSynth-Studio/examples_in_diffsynth\", local_dir=\"./\", allow_file_pattern=f\"data/examples/nexusgen/cat.jpg\")\nref_image = Image.open(\"data/examples/nexusgen/cat.jpg\").convert(\"RGB\")\nprompt = \"Add a crown.\"\nimage = pipe(\n    prompt=prompt, negative_prompt=\"\",\n    seed=42, cfg_scale=2.0, num_inference_steps=50,\n    nexus_gen_reference_image=ref_image,\n    height=512, width=512,\n)\nimage.save(\"cat_crown.jpg\")\n"
  },
  {
    "path": "examples/flux/model_inference/Nexus-Gen-Generation.py",
    "content": "import importlib\nimport torch\nfrom diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig\n\n\nif importlib.util.find_spec(\"transformers\") is None:\n    raise ImportError(\"You are using Nexus-GenV2. It depends on transformers, which is not installed. Please install it with `pip install transformers==4.49.0`.\")\nelse:\n    import transformers\n    assert transformers.__version__ == \"4.49.0\", \"Nexus-GenV2 requires transformers==4.49.0, please install it with `pip install transformers==4.49.0`.\"\n\n\npipe = FluxImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"DiffSynth-Studio/Nexus-GenV2\", origin_file_pattern=\"model*.safetensors\"),\n        ModelConfig(model_id=\"DiffSynth-Studio/Nexus-GenV2\", origin_file_pattern=\"generation_decoder.bin\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder/model.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder_2/*.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"ae.safetensors\"),\n    ],\n    nexus_gen_processor_config=ModelConfig(\"DiffSynth-Studio/Nexus-GenV2\", origin_file_pattern=\"processor\"),\n)\n\nprompt = \"一只可爱的猫咪\"\nimage = pipe(\n    prompt=prompt, negative_prompt=\"\",\n    seed=0, cfg_scale=3, num_inference_steps=50,\n    height=1024, width=1024,\n)\nimage.save(\"cat.jpg\")\n"
  },
  {
    "path": "examples/flux/model_inference/Step1X-Edit.py",
    "content": "import torch\nfrom diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig\nfrom PIL import Image\nimport numpy as np\n\n\npipe = FluxImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen2.5-VL-7B-Instruct\", origin_file_pattern=\"model-*.safetensors\"),\n        ModelConfig(model_id=\"stepfun-ai/Step1X-Edit\", origin_file_pattern=\"step1x-edit-i1258.safetensors\"),\n        ModelConfig(model_id=\"stepfun-ai/Step1X-Edit\", origin_file_pattern=\"vae.safetensors\"),\n    ],\n)\n\nimage = Image.fromarray(np.zeros((1248, 832, 3), dtype=np.uint8) + 255)\nimage = pipe(\n    prompt=\"draw red flowers in Chinese ink painting style\",\n    step1x_reference_image=image,\n    width=832, height=1248, cfg_scale=6,\n    seed=1, rand_device='cuda'\n)\nimage.save(\"image_1.jpg\")\n\nimage = pipe(\n    prompt=\"add more flowers in Chinese ink painting style\",\n    step1x_reference_image=image,\n    width=832, height=1248, cfg_scale=6,\n    seed=2, rand_device='cuda'\n)\nimage.save(\"image_2.jpg\")\n"
  },
  {
    "path": "examples/flux/model_inference_low_vram/FLEX.2-preview.py",
    "content": "import torch\nfrom diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig\nfrom diffsynth.utils.controlnet import Annotator\nimport numpy as np\nfrom PIL import Image\n\n\nvram_config = {\n    \"offload_dtype\": torch.float8_e4m3fn,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.float8_e4m3fn,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e4m3fn,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = FluxImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"ostris/Flex.2-preview\", origin_file_pattern=\"Flex.2-preview.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder/model.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder_2/*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"ae.safetensors\", **vram_config),\n    ],\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\n\nimage = pipe(\n    prompt=\"portrait of a beautiful Asian girl, long hair, red t-shirt, sunshine, beach\",\n    num_inference_steps=50, embedded_guidance=3.5,\n    seed=0\n)\nimage.save(\"image_1.jpg\")\n\nmask = np.zeros((1024, 1024, 3), dtype=np.uint8)\nmask[200:400, 400:700] = 255\nmask = Image.fromarray(mask)\nmask.save(\"image_mask.jpg\")\n\ninpaint_image = image\n\nimage = pipe(\n    prompt=\"portrait of a beautiful Asian girl with sunglasses, long hair, red t-shirt, sunshine, beach\",\n    num_inference_steps=50, embedded_guidance=3.5,\n    flex_inpaint_image=inpaint_image, flex_inpaint_mask=mask,\n    seed=4\n)\nimage.save(\"image_2.jpg\")\n\ncontrol_image = 
Annotator(\"canny\")(image)\ncontrol_image.save(\"image_control.jpg\")\n\nimage = pipe(\n    prompt=\"portrait of a beautiful Asian girl with sunglasses, long hair, yellow t-shirt, sunshine, beach\",\n    num_inference_steps=50, embedded_guidance=3.5,\n    flex_control_image=control_image,\n    seed=4\n)\nimage.save(\"image_3.jpg\")\n"
  },
  {
    "path": "examples/flux/model_inference_low_vram/FLUX.1-Kontext-dev.py",
    "content": "import torch\nfrom diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig\nfrom PIL import Image\n\n\nvram_config = {\n    \"offload_dtype\": torch.float8_e4m3fn,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.float8_e4m3fn,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e4m3fn,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = FluxImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-Kontext-dev\", origin_file_pattern=\"flux1-kontext-dev.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder/model.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder_2/*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"ae.safetensors\", **vram_config),\n    ],\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\n\nimage_1 = pipe(\n    prompt=\"a beautiful Asian long-haired female college student.\",\n    embedded_guidance=2.5,\n    seed=1,\n)\nimage_1.save(\"image_1.jpg\")\n\nimage_2 = pipe(\n    prompt=\"transform the style to anime style.\",\n    kontext_images=image_1,\n    embedded_guidance=2.5,\n    seed=2,\n)\nimage_2.save(\"image_2.jpg\")\n\nimage_3 = pipe(\n    prompt=\"let her smile.\",\n    kontext_images=image_1,\n    embedded_guidance=2.5,\n    seed=3,\n)\nimage_3.save(\"image_3.jpg\")\n\nimage_4 = pipe(\n    prompt=\"let the girl play basketball.\",\n    kontext_images=image_1,\n    embedded_guidance=2.5,\n    seed=4,\n)\nimage_4.save(\"image_4.jpg\")\n\nimage_5 = pipe(\n    prompt=\"move the girl to a park, let her sit on a chair.\",\n    kontext_images=image_1,\n    
embedded_guidance=2.5,\n    seed=5,\n)\nimage_5.save(\"image_5.jpg\")"
  },
  {
    "path": "examples/flux/model_inference_low_vram/FLUX.1-Krea-dev.py",
    "content": "import torch\nfrom diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig\n\n\nvram_config = {\n    \"offload_dtype\": torch.float8_e4m3fn,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.float8_e4m3fn,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e4m3fn,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = FluxImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-Krea-dev\", origin_file_pattern=\"flux1-krea-dev.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder/model.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder_2/*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"ae.safetensors\", **vram_config),\n    ],\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\n\nprompt = \"An beautiful woman is riding a bicycle in a park, wearing a red dress\"\nnegative_prompt = \"worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,\"\n\nimage = pipe(prompt=prompt, seed=0, embedded_guidance=4.5)\nimage.save(\"flux_krea.jpg\")\n\nimage = pipe(\n    prompt=prompt, negative_prompt=negative_prompt,\n    seed=0, cfg_scale=2, num_inference_steps=50,\n    embedded_guidance=4.5\n)\nimage.save(\"flux_krea_cfg.jpg\")\n"
  },
  {
    "path": "examples/flux/model_inference_low_vram/FLUX.1-dev-AttriCtrl.py",
    "content": "import torch\nfrom diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig\n\n\nvram_config = {\n    \"offload_dtype\": torch.float8_e4m3fn,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.float8_e4m3fn,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e4m3fn,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = FluxImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"flux1-dev.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder/model.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder_2/*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"ae.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/AttriCtrl-FLUX.1-Dev\", origin_file_pattern=\"models/brightness.safetensors\", **vram_config)\n    ],\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\n\nfor i in [0.1, 0.3, 0.5, 0.7, 0.9]:\n    image = pipe(prompt=\"a cat on the beach\", seed=2, value_controller_inputs=[i])\n    image.save(f\"value_control_{i}.jpg\")\n"
  },
  {
    "path": "examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Inpainting-Beta.py",
    "content": "import torch\nfrom diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig, ControlNetInput\nimport numpy as np\nfrom PIL import Image\n\n\nvram_config = {\n    \"offload_dtype\": torch.float8_e4m3fn,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.float8_e4m3fn,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e4m3fn,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = FluxImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"flux1-dev.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder/model.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder_2/*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"ae.safetensors\", **vram_config),\n        ModelConfig(model_id=\"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta\", origin_file_pattern=\"diffusion_pytorch_model.safetensors\", **vram_config),\n    ],\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\n\nimage_1 = pipe(\n    prompt=\"a cat sitting on a chair\",\n    height=1024, width=1024,\n    seed=8, rand_device=\"cuda\",\n)\nimage_1.save(\"image_1.jpg\")\n\nmask = np.zeros((1024, 1024, 3), dtype=np.uint8)\nmask[100:350, 350: -300] = 255\nmask = Image.fromarray(mask)\nmask.save(\"mask.jpg\")\n\nimage_2 = pipe(\n    prompt=\"a cat sitting on a chair, wearing sunglasses\",\n    controlnet_inputs=[ControlNetInput(image=image_1, inpaint_mask=mask, scale=0.9)],\n    height=1024, width=1024,\n    seed=9, rand_device=\"cuda\",\n)\nimage_2.save(\"image_2.jpg\")"
  },
  {
    "path": "examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py",
    "content": "import torch\nfrom diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig, ControlNetInput\nfrom diffsynth.utils.controlnet import Annotator\nfrom modelscope import snapshot_download\n\n\nvram_config = {\n    \"offload_dtype\": torch.float8_e4m3fn,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.float8_e4m3fn,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e4m3fn,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\nsnapshot_download(\"sd_lora/Annotators\", allow_file_pattern=\"dpt_hybrid-midas-501f0c75.pt\", local_dir=\"models/Annotators\")\npipe = FluxImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"flux1-dev.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder/model.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder_2/*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"ae.safetensors\", **vram_config),\n        ModelConfig(model_id=\"InstantX/FLUX.1-dev-Controlnet-Union-alpha\", origin_file_pattern=\"diffusion_pytorch_model.safetensors\", **vram_config),\n    ],\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\n\nimage_1 = pipe(\n    prompt=\"a beautiful Asian girl, full body, red dress, summer\",\n    height=1024, width=1024,\n    seed=6, rand_device=\"cuda\",\n)\nimage_1.save(\"image_1.jpg\")\n\nimage_canny = Annotator(\"canny\")(image_1)\nimage_depth = Annotator(\"depth\")(image_1)\n\nimage_2 = pipe(\n    prompt=\"a beautiful Asian girl, full body, red dress, winter\",\n    controlnet_inputs=[\n        
ControlNetInput(image=image_canny, scale=0.3, processor_id=\"canny\"),\n        ControlNetInput(image=image_depth, scale=0.3, processor_id=\"depth\"),\n    ],\n    height=1024, width=1024,\n    seed=7, rand_device=\"cuda\",\n)\nimage_2.save(\"image_2.jpg\")\n"
  },
  {
    "path": "examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py",
    "content": "import torch\nfrom diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig, ControlNetInput\n\n\nvram_config = {\n    \"offload_dtype\": torch.float8_e4m3fn,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.float8_e4m3fn,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e4m3fn,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = FluxImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"flux1-dev.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder/model.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder_2/*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"ae.safetensors\", **vram_config),\n        ModelConfig(model_id=\"jasperai/Flux.1-dev-Controlnet-Upscaler\", origin_file_pattern=\"diffusion_pytorch_model.safetensors\", **vram_config),\n    ],\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\n\nimage_1 = pipe(\n    prompt=\"a photo of a cat, highly detailed\",\n    height=768, width=768,\n    seed=0, rand_device=\"cuda\",\n)\nimage_1.save(\"image_1.jpg\")\n\nimage_1 = image_1.resize((2048, 2048))\nimage_2 = pipe(\n    prompt=\"a photo of a cat, highly detailed\",\n    controlnet_inputs=[ControlNetInput(image=image_1, scale=0.7)],\n    input_image=image_1,\n    denoising_strength=0.99,\n    height=2048, width=2048, tiled=True,\n    seed=1, rand_device=\"cuda\",\n)\nimage_2.save(\"image_2.jpg\")"
  },
  {
    "path": "examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py",
    "content": "import random\nimport torch\nfrom PIL import Image, ImageDraw, ImageFont\nfrom diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\nvram_config = {\n    \"offload_dtype\": torch.float8_e4m3fn,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.float8_e4m3fn,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e4m3fn,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\ndef visualize_masks(image, masks, mask_prompts, output_path, font_size=35, use_random_colors=False):\n    # Create a blank image for overlays\n    overlay = Image.new('RGBA', image.size, (0, 0, 0, 0))\n    \n    colors = [\n        (165, 238, 173, 80),\n        (76, 102, 221, 80),\n        (221, 160, 77, 80),\n        (204, 93, 71, 80),\n        (145, 187, 149, 80),\n        (134, 141, 172, 80),\n        (157, 137, 109, 80),\n        (153, 104, 95, 80),\n        (165, 238, 173, 80),\n        (76, 102, 221, 80),\n        (221, 160, 77, 80),\n        (204, 93, 71, 80),\n        (145, 187, 149, 80),\n        (134, 141, 172, 80),\n        (157, 137, 109, 80),\n        (153, 104, 95, 80),\n    ]\n    # Generate random colors for each mask\n    if use_random_colors:\n        colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 80) for _ in range(len(masks))]\n    \n    # Font settings\n    try:\n        font = ImageFont.truetype(\"arial\", font_size)  # Adjust as needed\n    except IOError:\n        font = ImageFont.load_default(font_size)\n\n    # Overlay each mask onto the overlay image\n    for mask, mask_prompt, color in zip(masks, mask_prompts, colors):\n        # Convert mask to RGBA mode\n        mask_rgba = mask.convert('RGBA')\n        mask_data = mask_rgba.getdata()\n        new_data = [(color if item[:3] == (255, 255, 255) else (0, 0, 0, 0)) for item in mask_data]\n        
mask_rgba.putdata(new_data)\n\n        # Draw the mask prompt text on the mask\n        draw = ImageDraw.Draw(mask_rgba)\n        mask_bbox = mask.getbbox()  # Get the bounding box of the mask\n        text_position = (mask_bbox[0] + 10, mask_bbox[1] + 10)  # Adjust text position based on mask position\n        draw.text(text_position, mask_prompt, fill=(255, 255, 255, 255), font=font)\n\n        # Alpha composite the overlay with this mask\n        overlay = Image.alpha_composite(overlay, mask_rgba)\n    \n    # Composite the overlay onto the original image\n    result = Image.alpha_composite(image.convert('RGBA'), overlay)\n    \n    # Save or display the resulting image\n    result.save(output_path)\n\n    return result\n\ndef example(pipe, seeds, example_id, global_prompt, entity_prompts):\n    dataset_snapshot_download(dataset_id=\"DiffSynth-Studio/examples_in_diffsynth\", local_dir=\"./\", allow_file_pattern=f\"data/examples/eligen/entity_control/example_{example_id}/*.png\")\n    masks = [Image.open(f\"./data/examples/eligen/entity_control/example_{example_id}/{i}.png\").convert('RGB') for i in range(len(entity_prompts))]\n    negative_prompt = \"worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,\"\n    for seed in seeds:\n        # generate image\n        image = pipe(\n            prompt=global_prompt,\n            cfg_scale=3.0,\n            negative_prompt=negative_prompt,\n            num_inference_steps=50,\n            embedded_guidance=3.5,\n            seed=seed,\n            height=1024,\n            width=1024,\n            eligen_entity_prompts=entity_prompts,\n            eligen_entity_masks=masks,\n        )\n        image.save(f\"eligen_example_{example_id}_{seed}.png\")\n        visualize_masks(image, masks, entity_prompts, f\"eligen_example_{example_id}_mask_{seed}.png\")\n\n\npipe = FluxImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n      
  ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"flux1-dev.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder/model.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder_2/*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"ae.safetensors\", **vram_config),\n    ],\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\npipe.load_lora(pipe.dit, ModelConfig(model_id=\"DiffSynth-Studio/Eligen\", origin_file_pattern=\"model_bf16.safetensors\"), alpha=1)\n\n# example 1\nglobal_prompt = \"A breathtaking beauty of Raja Ampat by the late-night moonlight , one beautiful woman from behind wearing a pale blue long dress with soft glow, sitting at the top of a cliff looking towards the beach,pastell light colors, a group of small distant birds flying in far sky, a boat sailing on the sea, best quality, realistic, whimsical, fantastic, splash art, intricate detailed, hyperdetailed, maximalist style, photorealistic, concept art, sharp focus, harmony, serenity, tranquility, soft pastell colors,ambient occlusion, cozy ambient lighting, masterpiece, liiv1, linquivera, metix, mentixis, masterpiece, award winning, view from above\\n\"\nentity_prompts = [\"cliff\", \"sea\", \"moon\", \"sailing boat\", \"a seated beautiful woman\", \"pale blue long dress with soft glow\"]\nexample(pipe, [0], 1, global_prompt, entity_prompts)\n\n# example 2\nglobal_prompt = \"samurai girl wearing a kimono, she's holding a sword  glowing with red flame, her long hair is flowing in the wind, she is looking at a small bird perched on the back of her hand. ultra realist style. maximum image detail. 
maximum realistic render.\"\nentity_prompts = [\"flowing hair\", \"sword glowing with red flame\", \"A cute bird\", \"blue belt\"]\nexample(pipe, [0], 2, global_prompt, entity_prompts)\n\n# example 3\nglobal_prompt = \"Image of a neverending staircase up to a mysterious palace in the sky, The ancient palace stood majestically atop a mist-shrouded mountain, sunrise, two traditional monk walk in the stair looking at the sunrise, fog,see-through, best quality, whimsical, fantastic, splash art, intricate detailed, hyperdetailed, photorealistic, concept art, harmony, serenity, tranquility, ambient occlusion, halation, cozy ambient lighting, dynamic lighting,masterpiece, liiv1, linquivera, metix, mentixis, masterpiece, award winning,\"\nentity_prompts = [\"ancient palace\", \"stone staircase with railings\", \"a traditional monk\", \"a traditional monk\"]\nexample(pipe, [27], 3, global_prompt, entity_prompts)\n\n# example 4\nglobal_prompt = \"A beautiful girl wearing shirt and shorts in the street,  holding a sign 'Entity Control'\"\nentity_prompts = [\"A beautiful girl\", \"sign 'Entity Control'\", \"shorts\", \"shirt\"]\nexample(pipe, [21], 4, global_prompt, entity_prompts)\n\n# example 5\nglobal_prompt = \"A captivating, dramatic scene in a painting that exudes mystery and foreboding. A white sky, swirling blue clouds, and a crescent yellow moon illuminate a solitary woman standing near the water's edge. Her long dress flows in the wind, silhouetted against the eerie glow. 
The water mirrors the fiery sky and moonlight, amplifying the uneasy atmosphere.\"\nentity_prompts = [\"crescent yellow moon\", \"a solitary woman\", \"water\", \"swirling blue clouds\"]\nexample(pipe, [0], 5, global_prompt, entity_prompts)\n\n# example 6\nglobal_prompt = \"Snow White and the 6 Dwarfs.\"\nentity_prompts = [\"Dwarf 1\", \"Dwarf 2\", \"Dwarf 3\", \"Snow White\", \"Dwarf 4\", \"Dwarf 5\", \"Dwarf 6\"]\nexample(pipe, [8], 6, global_prompt, entity_prompts)\n\n# example 7, same prompt with different seeds\nseeds = range(5, 9)\nglobal_prompt = \"A beautiful woman wearing white dress, holding a mirror, with a warm light background;\"\nentity_prompts = [\"A beautiful woman\", \"mirror\", \"necklace\", \"glasses\", \"earring\", \"white dress\", \"jewelry headpiece\"]\nexample(pipe, seeds, 7, global_prompt, entity_prompts)\n"
  },
  {
    "path": "examples/flux/model_inference_low_vram/FLUX.1-dev-IP-Adapter.py",
    "content": "import torch\nfrom diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig\n\n\nvram_config = {\n    \"offload_dtype\": torch.float8_e4m3fn,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.float8_e4m3fn,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e4m3fn,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = FluxImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"flux1-dev.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder/model.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder_2/*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"ae.safetensors\", **vram_config),\n        ModelConfig(model_id=\"InstantX/FLUX.1-dev-IP-Adapter\", origin_file_pattern=\"ip-adapter.bin\", **vram_config),\n        ModelConfig(model_id=\"google/siglip-so400m-patch14-384\", origin_file_pattern=\"model.safetensors\", **vram_config),\n    ],\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\n\norigin_prompt = \"a rabbit in a garden, colorful flowers\"\nimage = pipe(prompt=origin_prompt, height=1280, width=960, seed=42)\nimage.save(\"style image.jpg\")\n\nimage = pipe(prompt=\"A piggy\", height=1280, width=960, seed=42,\n    ipadapter_images=[image], ipadapter_scale=0.7)\nimage.save(\"A piggy.jpg\")\n"
  },
  {
    "path": "examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py",
    "content": "import torch\nfrom diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig, ControlNetInput\nfrom modelscope import dataset_snapshot_download\nfrom modelscope import snapshot_download\nfrom PIL import Image\nimport numpy as np\n\n\n# This model has additional requirements.\n# Please install the following packages.\n# pip install facexlib insightface onnxruntime\nvram_config = {\n    \"offload_dtype\": torch.float8_e4m3fn,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.float8_e4m3fn,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e4m3fn,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\nsnapshot_download(\n    \"ByteDance/InfiniteYou\",\n    allow_file_pattern=\"supports/insightface/models/antelopev2/*\",\n    local_dir=\"models/ByteDance/InfiniteYou\",\n)\npipe = FluxImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"flux1-dev.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder/model.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder_2/*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"ae.safetensors\", **vram_config),\n        ModelConfig(model_id=\"ByteDance/InfiniteYou\", origin_file_pattern=\"infu_flux_v1.0/aes_stage2/image_proj_model.bin\", **vram_config),\n        ModelConfig(model_id=\"ByteDance/InfiniteYou\", origin_file_pattern=\"infu_flux_v1.0/aes_stage2/InfuseNetModel/*.safetensors\", **vram_config),\n    ],\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\n\ndataset_snapshot_download(\n    
dataset_id=\"DiffSynth-Studio/examples_in_diffsynth\",\n    local_dir=\"./\",\n    allow_file_pattern=f\"data/examples/infiniteyou/*\",\n)\n\nheight, width = 1024, 1024\ncontrolnet_image = Image.fromarray(np.zeros([height, width, 3]).astype(np.uint8))\ncontrolnet_inputs = [ControlNetInput(image=controlnet_image, scale=1.0, processor_id=\"None\")]\n\nprompt = \"A man, portrait, cinematic\"\nid_image = \"data/examples/infiniteyou/man.jpg\"\nid_image = Image.open(id_image).convert('RGB')\nimage = pipe(\n    prompt=prompt, seed=1,\n    infinityou_id_image=id_image, infinityou_guidance=1.0,\n    controlnet_inputs=controlnet_inputs,\n    num_inference_steps=50, embedded_guidance=3.5,\n    height=height, width=width,\n)\nimage.save(\"man.jpg\")\n\nprompt = \"A woman, portrait, cinematic\"\nid_image = \"data/examples/infiniteyou/woman.jpg\"\nid_image = Image.open(id_image).convert('RGB')\nimage = pipe(\n    prompt=prompt, seed=1,\n    infinityou_id_image=id_image, infinityou_guidance=1.0,\n    controlnet_inputs=controlnet_inputs,\n    num_inference_steps=50, embedded_guidance=3.5,\n    height=height, width=width,\n)\nimage.save(\"woman.jpg\")"
  },
  {
    "path": "examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py",
    "content": "import torch\nfrom diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig\n\n\nvram_config = {\n    \"offload_dtype\": torch.float8_e4m3fn,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.float8_e4m3fn,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e4m3fn,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = FluxImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"flux1-dev.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder/model.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder_2/*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"ae.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev\", origin_file_pattern=\"model.safetensors\", **vram_config),\n    ],\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\nlora = ModelConfig(model_id=\"VoidOc/flux_animal_forest1\", origin_file_pattern=\"20.safetensors\")\npipe.load_lora(pipe.dit, lora) # Use `pipe.clear_lora()` to drop the loaded LoRA.\n\n# Empty prompt can automatically activate LoRA capabilities.\nimage = pipe(prompt=\"\", seed=0, lora_encoder_inputs=lora)\nimage.save(\"image_1.jpg\")\n\nimage = pipe(prompt=\"\", seed=0)\nimage.save(\"image_1_origin.jpg\")\n\n# Prompt without trigger words can also activate LoRA capabilities.\nimage = pipe(prompt=\"a car\", seed=0, lora_encoder_inputs=lora)\nimage.save(\"image_2.jpg\")\n\nimage = pipe(prompt=\"a car\", seed=0,)\nimage.save(\"image_2_origin.jpg\")\n\n# Adjust the activation intensity 
through the scale parameter.\nimage = pipe(prompt=\"a cat\", seed=0, lora_encoder_inputs=lora, lora_encoder_scale=1.0)\nimage.save(\"image_3.jpg\")\n\nimage = pipe(prompt=\"a cat\", seed=0, lora_encoder_inputs=lora, lora_encoder_scale=0.5)\nimage.save(\"image_3_scale.jpg\")\n"
  },
  {
    "path": "examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Fusion.py",
    "content": "import torch\nfrom diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig\n\n\nvram_config = {\n    \"offload_dtype\": torch.float8_e4m3fn,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.float8_e4m3fn,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e4m3fn,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = FluxImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"flux1-dev.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder/model.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder_2/*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"ae.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev\", origin_file_pattern=\"model.safetensors\", **vram_config),\n    ],\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\npipe.enable_lora_merger()\n\npipe.load_lora(\n    pipe.dit,\n    ModelConfig(model_id=\"cancel13/cxsk\", origin_file_pattern=\"30.safetensors\"),\n)\npipe.load_lora(\n    pipe.dit,\n    ModelConfig(model_id=\"DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1\", origin_file_pattern=\"merged_lora.safetensors\"),\n)\nimage = pipe(prompt=\"a cat\", seed=0)\nimage.save(\"image_fused.jpg\")\n"
  },
  {
    "path": "examples/flux/model_inference_low_vram/FLUX.1-dev.py",
    "content": "import torch\nfrom diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig\n\n\nvram_config = {\n    \"offload_dtype\": torch.float8_e4m3fn,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.float8_e4m3fn,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e4m3fn,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = FluxImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"flux1-dev.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder/model.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder_2/*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"ae.safetensors\", **vram_config),\n    ],\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\n\nprompt = \"CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her.\"\nnegative_prompt = \"worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,\"\n\nimage = pipe(prompt=prompt, seed=0)\nimage.save(\"flux.jpg\")\n\nimage = pipe(\n    prompt=prompt, negative_prompt=negative_prompt,\n    seed=0, cfg_scale=2, num_inference_steps=50,\n)\nimage.save(\"flux_cfg.jpg\")\n"
  },
  {
    "path": "examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py",
    "content": "import importlib\nimport torch\nfrom PIL import Image\nfrom diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\nif importlib.util.find_spec(\"transformers\") is None:\n    raise ImportError(\"You are using Nexus-GenV2. It depends on transformers, which is not installed. Please install it with `pip install transformers==4.49.0`.\")\nelse:\n    import transformers\n    assert transformers.__version__ == \"4.49.0\", \"Nexus-GenV2 requires transformers==4.49.0, please install it with `pip install transformers==4.49.0`.\"\n\n\nvram_config = {\n    \"offload_dtype\": torch.float8_e4m3fn,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.float8_e4m3fn,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e4m3fn,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = FluxImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"DiffSynth-Studio/Nexus-GenV2\", origin_file_pattern=\"model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/Nexus-GenV2\", origin_file_pattern=\"edit_decoder.bin\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder/model.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder_2/*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"ae.safetensors\", **vram_config),\n    ],\n    nexus_gen_processor_config=ModelConfig(model_id=\"DiffSynth-Studio/Nexus-GenV2\", origin_file_pattern=\"processor/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 
0.5,\n)\n\ndataset_snapshot_download(dataset_id=\"DiffSynth-Studio/examples_in_diffsynth\", local_dir=\"./\", allow_file_pattern=f\"data/examples/nexusgen/cat.jpg\")\nref_image = Image.open(\"data/examples/nexusgen/cat.jpg\").convert(\"RGB\")\nprompt = \"Add a crown.\"\nimage = pipe(\n    prompt=prompt, negative_prompt=\"\",\n    seed=42, cfg_scale=2.0, num_inference_steps=50,\n    nexus_gen_reference_image=ref_image,\n    height=512, width=512,\n)\nimage.save(\"cat_crown.jpg\")\n"
  },
  {
    "path": "examples/flux/model_inference_low_vram/Nexus-Gen-Generation.py",
    "content": "import importlib\nimport torch\nfrom diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig\n\n\nif importlib.util.find_spec(\"transformers\") is None:\n    raise ImportError(\"You are using Nexus-GenV2. It depends on transformers, which is not installed. Please install it with `pip install transformers==4.49.0`.\")\nelse:\n    import transformers\n    assert transformers.__version__ == \"4.49.0\", \"Nexus-GenV2 requires transformers==4.49.0, please install it with `pip install transformers==4.49.0`.\"\n\n\nvram_config = {\n    \"offload_dtype\": torch.float8_e4m3fn,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.float8_e4m3fn,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e4m3fn,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = FluxImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"DiffSynth-Studio/Nexus-GenV2\", origin_file_pattern=\"model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/Nexus-GenV2\", origin_file_pattern=\"generation_decoder.bin\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder/model.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder_2/*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"ae.safetensors\", **vram_config),\n    ],\n    nexus_gen_processor_config=ModelConfig(model_id=\"DiffSynth-Studio/Nexus-GenV2\", origin_file_pattern=\"processor\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\n\nprompt = \"一只可爱的猫咪\"\nimage = pipe(\n    prompt=prompt, negative_prompt=\"\",\n    seed=0, cfg_scale=3, num_inference_steps=50,\n    height=1024, 
width=1024,\n)\nimage.save(\"cat.jpg\")\n"
  },
  {
    "path": "examples/flux/model_inference_low_vram/Step1X-Edit.py",
    "content": "import torch\nfrom diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig\nfrom PIL import Image\nimport numpy as np\n\n\nvram_config = {\n    \"offload_dtype\": torch.float8_e4m3fn,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.float8_e4m3fn,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e4m3fn,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = FluxImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen2.5-VL-7B-Instruct\", origin_file_pattern=\"model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"stepfun-ai/Step1X-Edit\", origin_file_pattern=\"step1x-edit-i1258.safetensors\", **vram_config),\n        ModelConfig(model_id=\"stepfun-ai/Step1X-Edit\", origin_file_pattern=\"vae.safetensors\", **vram_config),\n    ],\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\n\nimage = Image.fromarray(np.zeros((1248, 832, 3), dtype=np.uint8) + 255)\nimage = pipe(\n    prompt=\"draw red flowers in Chinese ink painting style\",\n    step1x_reference_image=image,\n    width=832, height=1248, cfg_scale=6,\n    seed=1, rand_device='cuda'\n)\nimage.save(\"image_1.jpg\")\n\nimage = pipe(\n    prompt=\"add more flowers in Chinese ink painting style\",\n    step1x_reference_image=image,\n    width=832, height=1248, cfg_scale=6,\n    seed=2, rand_device='cuda'\n)\nimage.save(\"image_2.jpg\")\n"
  },
  {
    "path": "examples/flux/model_training/full/FLEX.2-preview.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"flux/FLEX.2-preview/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch --config_file examples/flux/model_training/full/accelerate_config.yaml examples/flux/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/flux/FLEX.2-preview \\\n  --dataset_metadata_path data/diffsynth_example_dataset/flux/FLEX.2-preview/metadata.csv \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 200 \\\n  --model_id_with_origin_paths \"ostris/Flex.2-preview:Flex.2-preview.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/*.safetensors,black-forest-labs/FLUX.1-dev:ae.safetensors\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 1 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/FLEX.2-preview_full\" \\\n  --trainable_models \"dit\" \\\n  --use_gradient_checkpointing\n"
  },
  {
    "path": "examples/flux/model_training/full/FLUX.1-Kontext-dev.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"flux/FLUX.1-Kontext-dev/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch --config_file examples/flux/model_training/full/accelerate_config.yaml examples/flux/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/flux/FLUX.1-Kontext-dev \\\n  --dataset_metadata_path data/diffsynth_example_dataset/flux/FLUX.1-Kontext-dev/metadata.csv \\\n  --data_file_keys \"image,kontext_images\" \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 400 \\\n  --model_id_with_origin_paths \"black-forest-labs/FLUX.1-Kontext-dev:flux1-kontext-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/*.safetensors,black-forest-labs/FLUX.1-dev:ae.safetensors\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 1 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/FLUX.1-Kontext-dev_full\" \\\n  --trainable_models \"dit\" \\\n  --extra_inputs \"kontext_images\" \\\n  --use_gradient_checkpointing\n"
  },
  {
    "path": "examples/flux/model_training/full/FLUX.1-Krea-dev.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"flux/FLUX.1-Krea-dev/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch --config_file examples/flux/model_training/full/accelerate_config.yaml examples/flux/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/flux/FLUX.1-Krea-dev \\\n  --dataset_metadata_path data/diffsynth_example_dataset/flux/FLUX.1-Krea-dev/metadata.csv \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 400 \\\n  --model_id_with_origin_paths \"black-forest-labs/FLUX.1-Krea-dev:flux1-krea-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/*.safetensors,black-forest-labs/FLUX.1-dev:ae.safetensors\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 1 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/FLUX.1-Krea-dev_full\" \\\n  --trainable_models \"dit\" \\\n  --use_gradient_checkpointing\n"
  },
  {
    "path": "examples/flux/model_training/full/FLUX.1-dev-AttriCtrl.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"flux/FLUX.1-dev-AttriCtrl/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/flux/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/flux/FLUX.1-dev-AttriCtrl \\\n  --dataset_metadata_path data/diffsynth_example_dataset/flux/FLUX.1-dev-AttriCtrl/metadata.csv \\\n  --data_file_keys \"image\" \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/*.safetensors,black-forest-labs/FLUX.1-dev:ae.safetensors,DiffSynth-Studio/AttriCtrl-FLUX.1-Dev:models/brightness.safetensors\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 1 \\\n  --remove_prefix_in_ckpt \"pipe.value_controller.encoders.0.\" \\\n  --output_path \"./models/train/FLUX.1-dev-AttriCtrl_full\" \\\n  --trainable_models \"value_controller\" \\\n  --extra_inputs \"value_controller_inputs\" \\\n  --use_gradient_checkpointing\n"
  },
  {
    "path": "examples/flux/model_training/full/FLUX.1-dev-Controlnet-Inpainting-Beta.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"flux/FLUX.1-dev-Controlnet-Inpainting-Beta/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch --config_file examples/flux/model_training/full/accelerate_config.yaml examples/flux/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/flux/FLUX.1-dev-Controlnet-Inpainting-Beta \\\n  --dataset_metadata_path data/diffsynth_example_dataset/flux/FLUX.1-dev-Controlnet-Inpainting-Beta/metadata.csv \\\n  --data_file_keys \"image,controlnet_image,controlnet_inpaint_mask\" \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 400 \\\n  --model_id_with_origin_paths \"black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/*.safetensors,black-forest-labs/FLUX.1-dev:ae.safetensors,alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta:diffusion_pytorch_model.safetensors\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 1 \\\n  --remove_prefix_in_ckpt \"pipe.controlnet.models.0.\" \\\n  --output_path \"./models/train/FLUX.1-dev-Controlnet-Inpainting-Beta_full\" \\\n  --trainable_models \"controlnet\" \\\n  --extra_inputs \"controlnet_image,controlnet_inpaint_mask\" \\\n  --use_gradient_checkpointing\n"
  },
  {
    "path": "examples/flux/model_training/full/FLUX.1-dev-Controlnet-Union-alpha.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"flux/FLUX.1-dev-Controlnet-Union-alpha/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch --config_file examples/flux/model_training/full/accelerate_config.yaml examples/flux/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/flux/FLUX.1-dev-Controlnet-Union-alpha \\\n  --dataset_metadata_path data/diffsynth_example_dataset/flux/FLUX.1-dev-Controlnet-Union-alpha/metadata.csv \\\n  --data_file_keys \"image,controlnet_image\" \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 400 \\\n  --model_id_with_origin_paths \"black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/*.safetensors,black-forest-labs/FLUX.1-dev:ae.safetensors,InstantX/FLUX.1-dev-Controlnet-Union-alpha:diffusion_pytorch_model.safetensors\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 1 \\\n  --remove_prefix_in_ckpt \"pipe.controlnet.models.0.\" \\\n  --output_path \"./models/train/FLUX.1-dev-Controlnet-Union-alpha_full\" \\\n  --trainable_models \"controlnet\" \\\n  --extra_inputs \"controlnet_image,controlnet_processor_id\" \\\n  --use_gradient_checkpointing\n"
  },
  {
    "path": "examples/flux/model_training/full/FLUX.1-dev-Controlnet-Upscaler.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"flux/FLUX.1-dev-Controlnet-Upscaler/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch --config_file examples/flux/model_training/full/accelerate_config.yaml examples/flux/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/flux/FLUX.1-dev-Controlnet-Upscaler \\\n  --dataset_metadata_path data/diffsynth_example_dataset/flux/FLUX.1-dev-Controlnet-Upscaler/metadata.csv \\\n  --data_file_keys \"image,controlnet_image\" \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 400 \\\n  --model_id_with_origin_paths \"black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/*.safetensors,black-forest-labs/FLUX.1-dev:ae.safetensors,jasperai/Flux.1-dev-Controlnet-Upscaler:diffusion_pytorch_model.safetensors\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 1 \\\n  --remove_prefix_in_ckpt \"pipe.controlnet.models.0.\" \\\n  --output_path \"./models/train/FLUX.1-dev-Controlnet-Upscaler_full\" \\\n  --trainable_models \"controlnet\" \\\n  --extra_inputs \"controlnet_image\" \\\n  --use_gradient_checkpointing\n"
  },
  {
    "path": "examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"flux/FLUX.1-dev-IP-Adapter/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/flux/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/flux/FLUX.1-dev-IP-Adapter \\\n  --dataset_metadata_path data/diffsynth_example_dataset/flux/FLUX.1-dev-IP-Adapter/metadata.csv \\\n  --data_file_keys \"image,ipadapter_images\" \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/*.safetensors,black-forest-labs/FLUX.1-dev:ae.safetensors,InstantX/FLUX.1-dev-IP-Adapter:ip-adapter.bin,google/siglip-so400m-patch14-384:model.safetensors\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 1 \\\n  --remove_prefix_in_ckpt \"pipe.ipadapter.\" \\\n  --output_path \"./models/train/FLUX.1-dev-IP-Adapter_full\" \\\n  --trainable_models \"ipadapter\" \\\n  --extra_inputs \"ipadapter_images\" \\\n  --use_gradient_checkpointing\n"
  },
  {
    "path": "examples/flux/model_training/full/FLUX.1-dev-InfiniteYou.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"flux/FLUX.1-dev-InfiniteYou/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch --config_file examples/flux/model_training/full/accelerate_config.yaml examples/flux/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/flux/FLUX.1-dev-InfiniteYou \\\n  --dataset_metadata_path data/diffsynth_example_dataset/flux/FLUX.1-dev-InfiniteYou/metadata.csv \\\n  --data_file_keys \"image,controlnet_image,infinityou_id_image\" \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 400 \\\n  --model_id_with_origin_paths \"black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/*.safetensors,black-forest-labs/FLUX.1-dev:ae.safetensors,ByteDance/InfiniteYou:infu_flux_v1.0/aes_stage2/image_proj_model.bin,ByteDance/InfiniteYou:infu_flux_v1.0/aes_stage2/InfuseNetModel/*.safetensors\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 1 \\\n  --remove_prefix_in_ckpt \"pipe.\" \\\n  --output_path \"./models/train/FLUX.1-dev-InfiniteYou_full\" \\\n  --trainable_models \"controlnet,image_proj_model\" \\\n  --extra_inputs \"controlnet_image,infinityou_id_image,infinityou_guidance\" \\\n  --use_gradient_checkpointing\n"
  },
  {
    "path": "examples/flux/model_training/full/FLUX.1-dev-LoRA-Encoder.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"flux/FLUX.1-dev-LoRA-Encoder/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/flux/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/flux/FLUX.1-dev-LoRA-Encoder \\\n  --dataset_metadata_path data/diffsynth_example_dataset/flux/FLUX.1-dev-LoRA-Encoder/metadata.csv \\\n  --data_file_keys \"image\" \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/*.safetensors,black-forest-labs/FLUX.1-dev:ae.safetensors,DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev:model.safetensors\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 1 \\\n  --remove_prefix_in_ckpt \"pipe.lora_encoder.\" \\\n  --output_path \"./models/train/FLUX.1-dev-LoRA-Encoder_full\" \\\n  --trainable_models \"lora_encoder\" \\\n  --extra_inputs \"lora_encoder_inputs\" \\\n  --use_gradient_checkpointing\n"
  },
  {
    "path": "examples/flux/model_training/full/FLUX.1-dev.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"flux/FLUX.1-dev/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch --config_file examples/flux/model_training/full/accelerate_config.yaml examples/flux/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/flux/FLUX.1-dev \\\n  --dataset_metadata_path data/diffsynth_example_dataset/flux/FLUX.1-dev/metadata.csv \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 400 \\\n  --model_id_with_origin_paths \"black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/*.safetensors,black-forest-labs/FLUX.1-dev:ae.safetensors\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 1 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/FLUX.1-dev_full\" \\\n  --trainable_models \"dit\" \\\n  --use_gradient_checkpointing\n"
  },
  {
    "path": "examples/flux/model_training/full/Nexus-Gen.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"flux/Nexus-Gen/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch --config_file examples/flux/model_training/full/accelerate_config_zero2offload.yaml examples/flux/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/flux/Nexus-Gen \\\n  --dataset_metadata_path data/diffsynth_example_dataset/flux/Nexus-Gen/metadata.csv \\\n  --data_file_keys \"image,nexus_gen_reference_image\" \\\n  --max_pixels 262144 \\\n  --dataset_repeat 400 \\\n  --model_id_with_origin_paths \"DiffSynth-Studio/Nexus-GenV2:model*.safetensors,DiffSynth-Studio/Nexus-GenV2:edit_decoder.bin,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/*.safetensors,black-forest-labs/FLUX.1-dev:ae.safetensors\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 1 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/FLUX.1-NexusGen-Edit_full\" \\\n  --trainable_models \"dit\" \\\n  --extra_inputs \"nexus_gen_reference_image\" \\\n  --use_gradient_checkpointing_offload\n"
  },
  {
    "path": "examples/flux/model_training/full/Step1X-Edit.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"flux/Step1X-Edit/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch --config_file examples/flux/model_training/full/accelerate_config.yaml examples/flux/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/flux/Step1X-Edit \\\n  --dataset_metadata_path data/diffsynth_example_dataset/flux/Step1X-Edit/metadata.csv \\\n  --data_file_keys \"image,step1x_reference_image\" \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 400 \\\n  --model_id_with_origin_paths \"Qwen/Qwen2.5-VL-7B-Instruct:model-*.safetensors,stepfun-ai/Step1X-Edit:step1x-edit-i1258.safetensors,stepfun-ai/Step1X-Edit:vae.safetensors\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 1 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Step1X-Edit_full\" \\\n  --trainable_models \"dit\" \\\n  --extra_inputs \"step1x_reference_image\" \\\n  --use_gradient_checkpointing_offload\n"
  },
  {
    "path": "examples/flux/model_training/full/accelerate_config.yaml",
    "content": "compute_environment: LOCAL_MACHINE\ndebug: false\ndeepspeed_config:\n  gradient_accumulation_steps: 1\n  offload_optimizer_device: none\n  offload_param_device: none\n  zero3_init_flag: false\n  zero_stage: 2\ndistributed_type: DEEPSPEED\ndowncast_bf16: 'no'\nenable_cpu_affinity: false\nmachine_rank: 0\nmain_training_function: main\nmixed_precision: bf16\nnum_machines: 1\nnum_processes: 8\nrdzv_backend: static\nsame_network: true\ntpu_env: []\ntpu_use_cluster: false\ntpu_use_sudo: false\nuse_cpu: false\n"
  },
  {
    "path": "examples/flux/model_training/full/accelerate_config_zero2offload.yaml",
    "content": "compute_environment: LOCAL_MACHINE\ndebug: false\ndeepspeed_config:\n  gradient_accumulation_steps: 1\n  offload_optimizer_device: 'cpu'\n  offload_param_device: 'cpu'\n  zero3_init_flag: false\n  zero_stage: 2\ndistributed_type: DEEPSPEED\ndowncast_bf16: 'no'\nenable_cpu_affinity: false\nmachine_rank: 0\nmain_training_function: main\nmixed_precision: bf16\nnum_machines: 1\nnum_processes: 8\nrdzv_backend: static\nsame_network: true\ntpu_env: []\ntpu_use_cluster: false\ntpu_use_sudo: false\nuse_cpu: false\n"
  },
  {
    "path": "examples/flux/model_training/full/accelerate_config_zero3.yaml",
    "content": "compute_environment: LOCAL_MACHINE\ndebug: false\ndeepspeed_config:\n  gradient_accumulation_steps: 1\n  offload_optimizer_device: none\n  offload_param_device: none\n  zero3_init_flag: true\n  zero3_save_16bit_model: true\n  zero_stage: 3\ndistributed_type: DEEPSPEED\ndowncast_bf16: 'no'\nenable_cpu_affinity: false\nmachine_rank: 0\nmain_training_function: main\nmixed_precision: bf16\nnum_machines: 1\nnum_processes: 8\nrdzv_backend: static\nsame_network: true\ntpu_env: []\ntpu_use_cluster: false\ntpu_use_sudo: false\nuse_cpu: false\n"
  },
  {
    "path": "examples/flux/model_training/lora/FLEX.2-preview.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"flux/FLEX.2-preview/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/flux/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/flux/FLEX.2-preview \\\n  --dataset_metadata_path data/diffsynth_example_dataset/flux/FLEX.2-preview/metadata.csv \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 50 \\\n  --model_id_with_origin_paths \"ostris/Flex.2-preview:Flex.2-preview.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/*.safetensors,black-forest-labs/FLUX.1-dev:ae.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/FLEX.2-preview_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp\" \\\n  --lora_rank 32 \\\n  --align_to_opensource_format \\\n  --use_gradient_checkpointing\n"
  },
  {
    "path": "examples/flux/model_training/lora/FLUX.1-Kontext-dev.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"flux/FLUX.1-Kontext-dev/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/flux/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/flux/FLUX.1-Kontext-dev \\\n  --dataset_metadata_path data/diffsynth_example_dataset/flux/FLUX.1-Kontext-dev/metadata.csv \\\n  --data_file_keys \"image,kontext_images\" \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 400 \\\n  --model_id_with_origin_paths \"black-forest-labs/FLUX.1-Kontext-dev:flux1-kontext-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/*.safetensors,black-forest-labs/FLUX.1-dev:ae.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/FLUX.1-Kontext-dev_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp\" \\\n  --lora_rank 32 \\\n  --align_to_opensource_format \\\n  --extra_inputs \"kontext_images\" \\\n  --use_gradient_checkpointing\n"
  },
  {
    "path": "examples/flux/model_training/lora/FLUX.1-Krea-dev.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"flux/FLUX.1-Krea-dev/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/flux/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/flux/FLUX.1-Krea-dev \\\n  --dataset_metadata_path data/diffsynth_example_dataset/flux/FLUX.1-Krea-dev/metadata.csv \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 50 \\\n  --model_id_with_origin_paths \"black-forest-labs/FLUX.1-Krea-dev:flux1-krea-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/*.safetensors,black-forest-labs/FLUX.1-dev:ae.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/FLUX.1-Krea-dev_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp\" \\\n  --lora_rank 32 \\\n  --align_to_opensource_format \\\n  --use_gradient_checkpointing\n"
  },
  {
    "path": "examples/flux/model_training/lora/FLUX.1-dev-AttriCtrl.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"flux/FLUX.1-dev-AttriCtrl/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/flux/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/flux/FLUX.1-dev-AttriCtrl \\\n  --dataset_metadata_path data/diffsynth_example_dataset/flux/FLUX.1-dev-AttriCtrl/metadata.csv \\\n  --data_file_keys \"image\" \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/*.safetensors,black-forest-labs/FLUX.1-dev:ae.safetensors,DiffSynth-Studio/AttriCtrl-FLUX.1-Dev:models/brightness.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/FLUX.1-dev-AttriCtrl_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp\" \\\n  --lora_rank 32 \\\n  --extra_inputs \"value_controller_inputs\" \\\n  --align_to_opensource_format \\\n  --use_gradient_checkpointing\n"
  },
  {
    "path": "examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Inpainting-Beta.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"flux/FLUX.1-dev-Controlnet-Inpainting-Beta/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/flux/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/flux/FLUX.1-dev-Controlnet-Inpainting-Beta \\\n  --dataset_metadata_path data/diffsynth_example_dataset/flux/FLUX.1-dev-Controlnet-Inpainting-Beta/metadata.csv \\\n  --data_file_keys \"image,controlnet_image,controlnet_inpaint_mask\" \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/*.safetensors,black-forest-labs/FLUX.1-dev:ae.safetensors,alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta:diffusion_pytorch_model.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/FLUX.1-dev-Controlnet-Inpainting-Beta_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp\" \\\n  --lora_rank 32 \\\n  --extra_inputs \"controlnet_image,controlnet_inpaint_mask\" \\\n  --align_to_opensource_format \\\n  --use_gradient_checkpointing\n"
  },
  {
    "path": "examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Union-alpha.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"flux/FLUX.1-dev-Controlnet-Union-alpha/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/flux/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/flux/FLUX.1-dev-Controlnet-Union-alpha \\\n  --dataset_metadata_path data/diffsynth_example_dataset/flux/FLUX.1-dev-Controlnet-Union-alpha/metadata.csv \\\n  --data_file_keys \"image,controlnet_image\" \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/*.safetensors,black-forest-labs/FLUX.1-dev:ae.safetensors,InstantX/FLUX.1-dev-Controlnet-Union-alpha:diffusion_pytorch_model.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/FLUX.1-dev-Controlnet-Union-alpha_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp\" \\\n  --lora_rank 32 \\\n  --extra_inputs \"controlnet_image,controlnet_processor_id\" \\\n  --align_to_opensource_format \\\n  --use_gradient_checkpointing\n"
  },
  {
    "path": "examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Upscaler.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"flux/FLUX.1-dev-Controlnet-Upscaler/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/flux/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/flux/FLUX.1-dev-Controlnet-Upscaler \\\n  --dataset_metadata_path data/diffsynth_example_dataset/flux/FLUX.1-dev-Controlnet-Upscaler/metadata.csv \\\n  --data_file_keys \"image,controlnet_image\" \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/*.safetensors,black-forest-labs/FLUX.1-dev:ae.safetensors,jasperai/Flux.1-dev-Controlnet-Upscaler:diffusion_pytorch_model.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/FLUX.1-dev-Controlnet-Upscaler_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp\" \\\n  --lora_rank 32 \\\n  --extra_inputs \"controlnet_image\" \\\n  --align_to_opensource_format \\\n  --use_gradient_checkpointing\n"
  },
  {
    "path": "examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"flux/FLUX.1-dev-EliGen/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/flux/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/flux/FLUX.1-dev-EliGen \\\n  --dataset_metadata_path data/diffsynth_example_dataset/flux/FLUX.1-dev-EliGen/metadata.json \\\n  --data_file_keys \"image,eligen_entity_masks\" \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 50 \\\n  --model_id_with_origin_paths \"black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/*.safetensors,black-forest-labs/FLUX.1-dev:ae.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/FLUX.1-dev-EliGen_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp\" \\\n  --lora_rank 32 \\\n  --align_to_opensource_format \\\n  --extra_inputs \"eligen_entity_masks,eligen_entity_prompts\" \\\n  --use_gradient_checkpointing\n"
  },
  {
    "path": "examples/flux/model_training/lora/FLUX.1-dev-IP-Adapter.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"flux/FLUX.1-dev-IP-Adapter/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/flux/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/flux/FLUX.1-dev-IP-Adapter \\\n  --dataset_metadata_path data/diffsynth_example_dataset/flux/FLUX.1-dev-IP-Adapter/metadata.csv \\\n  --data_file_keys \"image,ipadapter_images\" \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 50 \\\n  --model_id_with_origin_paths \"black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/*.safetensors,black-forest-labs/FLUX.1-dev:ae.safetensors,InstantX/FLUX.1-dev-IP-Adapter:ip-adapter.bin,google/siglip-so400m-patch14-384:model.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/FLUX.1-dev-IP-Adapter_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp\" \\\n  --lora_rank 32 \\\n  --extra_inputs \"ipadapter_images\" \\\n  --align_to_opensource_format \\\n  --use_gradient_checkpointing\n"
  },
  {
    "path": "examples/flux/model_training/lora/FLUX.1-dev-InfiniteYou.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"flux/FLUX.1-dev-InfiniteYou/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/flux/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/flux/FLUX.1-dev-InfiniteYou \\\n  --dataset_metadata_path data/diffsynth_example_dataset/flux/FLUX.1-dev-InfiniteYou/metadata.csv \\\n  --data_file_keys \"image,controlnet_image,infinityou_id_image\" \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/*.safetensors,black-forest-labs/FLUX.1-dev:ae.safetensors,ByteDance/InfiniteYou:infu_flux_v1.0/aes_stage2/image_proj_model.bin,ByteDance/InfiniteYou:infu_flux_v1.0/aes_stage2/InfuseNetModel/*.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/FLUX.1-dev-InfiniteYou_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp\" \\\n  --lora_rank 32 \\\n  --extra_inputs \"controlnet_image,infinityou_id_image,infinityou_guidance\" \\\n  --align_to_opensource_format \\\n  --use_gradient_checkpointing\n"
  },
  {
    "path": "examples/flux/model_training/lora/FLUX.1-dev.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"flux/FLUX.1-dev/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/flux/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/flux/FLUX.1-dev \\\n  --dataset_metadata_path data/diffsynth_example_dataset/flux/FLUX.1-dev/metadata.csv \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 50 \\\n  --model_id_with_origin_paths \"black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/*.safetensors,black-forest-labs/FLUX.1-dev:ae.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/FLUX.1-dev_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp\" \\\n  --lora_rank 32 \\\n  --align_to_opensource_format \\\n  --use_gradient_checkpointing\n"
  },
  {
    "path": "examples/flux/model_training/lora/Nexus-Gen.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"flux/Nexus-Gen/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/flux/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/flux/Nexus-Gen \\\n  --dataset_metadata_path data/diffsynth_example_dataset/flux/Nexus-Gen/metadata.csv \\\n  --data_file_keys \"image,nexus_gen_reference_image\" \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 400 \\\n  --model_id_with_origin_paths \"DiffSynth-Studio/Nexus-GenV2:model*.safetensors,DiffSynth-Studio/Nexus-GenV2:edit_decoder.bin,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/*.safetensors,black-forest-labs/FLUX.1-dev:ae.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/FLUX.1-NexusGen-Edit_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp\" \\\n  --lora_rank 32 \\\n  --align_to_opensource_format \\\n  --extra_inputs \"nexus_gen_reference_image\" \\\n  --use_gradient_checkpointing\n"
  },
  {
    "path": "examples/flux/model_training/lora/Step1X-Edit.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"flux/Step1X-Edit/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/flux/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/flux/Step1X-Edit \\\n  --dataset_metadata_path data/diffsynth_example_dataset/flux/Step1X-Edit/metadata.csv \\\n  --data_file_keys \"image,step1x_reference_image\" \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 50 \\\n  --model_id_with_origin_paths \"Qwen/Qwen2.5-VL-7B-Instruct:model-*.safetensors,stepfun-ai/Step1X-Edit:step1x-edit-i1258.safetensors,stepfun-ai/Step1X-Edit:vae.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Step1X-Edit_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp\" \\\n  --lora_rank 32 \\\n  --extra_inputs \"step1x_reference_image\" \\\n  --align_to_opensource_format \\\n  --use_gradient_checkpointing\n"
  },
  {
    "path": "examples/flux/model_training/special/npu_training/FLUX.1-Kontext-dev-NPU.sh",
    "content": "export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True\nexport CPU_AFFINITY_CONF=1\n\nmodelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"flux/FLUX.1-Kontext-dev/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch --config_file examples/flux/model_training/full/accelerate_config_zero2offload.yaml examples/flux/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/flux/FLUX.1-Kontext-dev \\\n  --dataset_metadata_path data/diffsynth_example_dataset/flux/FLUX.1-Kontext-dev/metadata.csv \\\n  --data_file_keys \"image,kontext_images\" \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 400 \\\n  --model_id_with_origin_paths \"black-forest-labs/FLUX.1-Kontext-dev:flux1-kontext-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/*.safetensors,black-forest-labs/FLUX.1-dev:ae.safetensors\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 1 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/FLUX.1-Kontext-dev_full\" \\\n  --trainable_models \"dit\" \\\n  --extra_inputs \"kontext_images\" \\\n  --use_gradient_checkpointing\n"
  },
  {
    "path": "examples/flux/model_training/special/npu_training/FLUX.1-dev-NPU.sh",
    "content": "export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True\nexport CPU_AFFINITY_CONF=1\n\nmodelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"flux/FLUX.1-dev/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch --config_file examples/flux/model_training/full/accelerate_config_zero2offload.yaml examples/flux/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/flux/FLUX.1-dev \\\n  --dataset_metadata_path data/diffsynth_example_dataset/flux/FLUX.1-dev/metadata.csv \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 400 \\\n  --model_id_with_origin_paths \"black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/*.safetensors,black-forest-labs/FLUX.1-dev:ae.safetensors\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 1 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/FLUX.1-dev_full\" \\\n  --trainable_models \"dit\" \\\n  --use_gradient_checkpointing\n"
  },
  {
    "path": "examples/flux/model_training/train.py",
    "content": "import torch, os, argparse, accelerate\nfrom diffsynth.core import UnifiedDataset\nfrom diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig\nfrom diffsynth.diffusion import *\nos.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n\n\nclass FluxTrainingModule(DiffusionTrainingModule):\n    def __init__(\n        self,\n        model_paths=None, model_id_with_origin_paths=None,\n        tokenizer_1_path=None, tokenizer_2_path=None,\n        trainable_models=None,\n        lora_base_model=None, lora_target_modules=\"\", lora_rank=32, lora_checkpoint=None,\n        preset_lora_path=None, preset_lora_model=None,\n        use_gradient_checkpointing=True,\n        use_gradient_checkpointing_offload=False,\n        extra_inputs=None,\n        fp8_models=None,\n        offload_models=None,\n        device=\"cpu\",\n        task=\"sft\",\n    ):\n        super().__init__()\n        # Load models\n        model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device)\n        tokenizer_1_config = ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"tokenizer/\") if tokenizer_1_path is None else ModelConfig(tokenizer_1_path)\n        tokenizer_2_config = ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"tokenizer_2/\") if tokenizer_2_path is None else ModelConfig(tokenizer_2_path)\n        self.pipe = FluxImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, tokenizer_1_config=tokenizer_1_config, tokenizer_2_config=tokenizer_2_config)\n        self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model)\n\n        # Training mode\n        self.switch_pipe_to_training_mode(\n            self.pipe, trainable_models,\n            lora_base_model, lora_target_modules, lora_rank, lora_checkpoint,\n            preset_lora_path, 
preset_lora_model,\n            task=task,\n        )\n        \n        # Other configs\n        self.use_gradient_checkpointing = use_gradient_checkpointing\n        self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload\n        self.extra_inputs = extra_inputs.split(\",\") if extra_inputs is not None else []\n        self.fp8_models = fp8_models\n        self.task = task\n        self.task_to_loss = {\n            \"sft:data_process\": lambda pipe, *args: args,\n            \"direct_distill:data_process\": lambda pipe, *args: args,\n            \"sft\": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),\n            \"sft:train\": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),\n            \"direct_distill\": lambda pipe, inputs_shared, inputs_posi, inputs_nega: DirectDistillLoss(pipe, **inputs_shared, **inputs_posi),\n            \"direct_distill:train\": lambda pipe, inputs_shared, inputs_posi, inputs_nega: DirectDistillLoss(pipe, **inputs_shared, **inputs_posi),\n        }\n        \n    def get_pipeline_inputs(self, data):\n        inputs_posi = {\"prompt\": data[\"prompt\"]}\n        inputs_nega = {\"negative_prompt\": \"\"}\n        inputs_shared = {\n            # Assume you are using this pipeline for inference,\n            # please fill in the input parameters.\n            \"input_image\": data[\"image\"],\n            \"height\": data[\"image\"].size[1],\n            \"width\": data[\"image\"].size[0],\n            # Please do not modify the following parameters\n            # unless you clearly know what this will cause.\n            \"cfg_scale\": 1,\n            \"embedded_guidance\": 1,\n            \"t5_sequence_length\": 512,\n            \"tiled\": False,\n            \"rand_device\": self.pipe.device,\n            \"use_gradient_checkpointing\": self.use_gradient_checkpointing,\n            
\"use_gradient_checkpointing_offload\": self.use_gradient_checkpointing_offload,\n        }\n        inputs_shared = self.parse_extra_inputs(data, self.extra_inputs, inputs_shared)\n        return inputs_shared, inputs_posi, inputs_nega\n    \n    def forward(self, data, inputs=None):\n        if inputs is None: inputs = self.get_pipeline_inputs(data)\n        inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype)\n        for unit in self.pipe.units:\n            inputs = self.pipe.unit_runner(unit, self.pipe, *inputs)\n        loss = self.task_to_loss[self.task](self.pipe, *inputs)\n        return loss\n\n\ndef flux_parser():\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser = add_general_config(parser)\n    parser = add_image_size_config(parser)\n    parser.add_argument(\"--tokenizer_1_path\", type=str, default=None, help=\"Path to CLIP tokenizer.\")\n    parser.add_argument(\"--tokenizer_2_path\", type=str, default=None, help=\"Path to T5 tokenizer.\")\n    parser.add_argument(\"--align_to_opensource_format\", default=False, action=\"store_true\", help=\"Whether to align the lora format to opensource format. 
Only for DiT's LoRA.\")\n    return parser\n\n\ndef convert_lora_format(state_dict, alpha=None):\n    prefix_rename_dict = {\n        \"single_blocks\": \"lora_unet_single_blocks\",\n        \"blocks\": \"lora_unet_double_blocks\",\n    }\n    middle_rename_dict = {\n        \"norm.linear\": \"modulation_lin\",\n        \"to_qkv_mlp\": \"linear1\",\n        \"proj_out\": \"linear2\",\n        \"norm1_a.linear\": \"img_mod_lin\",\n        \"norm1_b.linear\": \"txt_mod_lin\",\n        \"attn.a_to_qkv\": \"img_attn_qkv\",\n        \"attn.b_to_qkv\": \"txt_attn_qkv\",\n        \"attn.a_to_out\": \"img_attn_proj\",\n        \"attn.b_to_out\": \"txt_attn_proj\",\n        \"ff_a.0\": \"img_mlp_0\",\n        \"ff_a.2\": \"img_mlp_2\",\n        \"ff_b.0\": \"txt_mlp_0\",\n        \"ff_b.2\": \"txt_mlp_2\",\n    }\n    suffix_rename_dict = {\n        \"lora_B.weight\": \"lora_up.weight\",\n        \"lora_A.weight\": \"lora_down.weight\",\n    }\n    state_dict_ = {}\n    for name, param in state_dict.items():\n        names = name.split(\".\")\n        if names[-2] != \"lora_A\" and names[-2] != \"lora_B\":\n            names.pop(-2)\n        prefix = names[0]\n        middle = \".\".join(names[2:-2])\n        suffix = \".\".join(names[-2:])\n        block_id = names[1]\n        if middle not in middle_rename_dict:\n            continue\n        rename = prefix_rename_dict[prefix] + \"_\" + block_id + \"_\" + middle_rename_dict[middle] + \".\" + suffix_rename_dict[suffix]\n        state_dict_[rename] = param\n        if rename.endswith(\"lora_up.weight\"):\n            lora_alpha = alpha if alpha is not None else param.shape[-1]\n            state_dict_[rename.replace(\"lora_up.weight\", \"alpha\")] = torch.tensor((lora_alpha,))[0]\n    return state_dict_\n\n\nif __name__ == \"__main__\":\n    parser = flux_parser()\n    args = parser.parse_args()\n    accelerator = accelerate.Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        
kwargs_handlers=[accelerate.DistributedDataParallelKwargs(find_unused_parameters=args.find_unused_parameters)],\n    )\n    dataset = UnifiedDataset(\n        base_path=args.dataset_base_path,\n        metadata_path=args.dataset_metadata_path,\n        repeat=args.dataset_repeat,\n        data_file_keys=args.data_file_keys.split(\",\"),\n        main_data_operator=UnifiedDataset.default_image_operator(\n            base_path=args.dataset_base_path,\n            max_pixels=args.max_pixels,\n            height=args.height,\n            width=args.width,\n            height_division_factor=16,\n            width_division_factor=16,\n        )\n    )\n    model = FluxTrainingModule(\n        model_paths=args.model_paths,\n        model_id_with_origin_paths=args.model_id_with_origin_paths,\n        tokenizer_1_path=args.tokenizer_1_path,\n        tokenizer_2_path=args.tokenizer_2_path,\n        trainable_models=args.trainable_models,\n        lora_base_model=args.lora_base_model,\n        lora_target_modules=args.lora_target_modules,\n        lora_rank=args.lora_rank,\n        lora_checkpoint=args.lora_checkpoint,\n        preset_lora_path=args.preset_lora_path,\n        preset_lora_model=args.preset_lora_model,\n        use_gradient_checkpointing=args.use_gradient_checkpointing,\n        use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,\n        extra_inputs=args.extra_inputs,\n        fp8_models=args.fp8_models,\n        offload_models=args.offload_models,\n        task=args.task,\n        device=accelerator.device,\n    )\n    model_logger = ModelLogger(\n        args.output_path,\n        remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,\n        state_dict_converter=convert_lora_format if args.align_to_opensource_format else lambda x:x,\n    )\n    launcher_map = {\n        \"sft:data_process\": launch_data_process_task,\n        \"direct_distill:data_process\": launch_data_process_task,\n        \"sft\": launch_training_task,\n        
\"sft:train\": launch_training_task,\n        \"direct_distill\": launch_training_task,\n        \"direct_distill:train\": launch_training_task,\n    }\n    launcher_map[args.task](accelerator, dataset, model, model_logger, args=args)\n"
  },
  {
    "path": "examples/flux/model_training/validate_full/FLEX.2-preview.py",
    "content": "import torch\nfrom diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig\nfrom diffsynth import load_state_dict\n\n\npipe = FluxImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"ostris/Flex.2-preview\", origin_file_pattern=\"Flex.2-preview.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder/model.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder_2/*.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"ae.safetensors\"),\n    ],\n)\nstate_dict = load_state_dict(\"models/train/FLEX.2-preview_full/epoch-0.safetensors\")\npipe.dit.load_state_dict(state_dict)\n\nimage = pipe(prompt=\"dog,white and brown dog, sitting on wall, under pink flowers\", seed=0)\nimage.save(\"image_FLEX.2-preview_full.jpg\")\n"
  },
  {
    "path": "examples/flux/model_training/validate_full/FLUX.1-Kontext-dev.py",
    "content": "import torch\nfrom diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig\nfrom diffsynth import load_state_dict\nfrom PIL import Image\n\n\npipe = FluxImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-Kontext-dev\", origin_file_pattern=\"flux1-kontext-dev.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder/model.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder_2/*.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"ae.safetensors\"),\n    ],\n)\nstate_dict = load_state_dict(\"models/train/FLUX.1-Kontext-dev_full/epoch-0.safetensors\")\npipe.dit.load_state_dict(state_dict)\n\nimage = pipe(\n    prompt=\"Make the dog turn its head around.\",\n    kontext_images=Image.open(\"data/example_image_dataset/2.jpg\").resize((768, 768)),\n    height=768, width=768,\n    seed=0\n)\nimage.save(\"image_FLUX.1-Kontext-dev_full.jpg\")\n"
  },
  {
    "path": "examples/flux/model_training/validate_full/FLUX.1-Krea-dev.py",
    "content": "import torch\nfrom diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig\nfrom diffsynth import load_state_dict\n\n\npipe = FluxImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-Krea-dev\", origin_file_pattern=\"flux1-krea-dev.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder/model.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder_2/*.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"ae.safetensors\"),\n    ],\n)\nstate_dict = load_state_dict(\"models/train/FLUX.1-Krea-dev_full/epoch-0.safetensors\")\npipe.dit.load_state_dict(state_dict)\n\nimage = pipe(prompt=\"a dog\", seed=0)\nimage.save(\"image_FLUX.1-Krea-dev_full.jpg\")\n"
  },
  {
    "path": "examples/flux/model_training/validate_full/FLUX.1-dev-AttriCtrl.py",
    "content": "import torch\nfrom diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig\nfrom diffsynth import load_state_dict\n\n\npipe = FluxImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"flux1-dev.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder/model.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder_2/*.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"ae.safetensors\"),\n        ModelConfig(model_id=\"DiffSynth-Studio/AttriCtrl-FLUX.1-Dev\", origin_file_pattern=\"models/brightness.safetensors\")\n    ],\n)\nstate_dict = load_state_dict(\"models/train/FLUX.1-dev-AttriCtrl_full/epoch-0.safetensors\")\npipe.value_controller.encoders[0].load_state_dict(state_dict)\n\nimage = pipe(prompt=\"a cat\", seed=0, value_controller_inputs=0.1, rand_device=\"cuda\")\nimage.save(\"image_FLUX.1-dev-AttriCtrl_full.jpg\")\n"
  },
  {
    "path": "examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Inpainting-Beta.py",
    "content": "import torch\nfrom diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig, ControlNetInput\nfrom diffsynth import load_state_dict\nfrom PIL import Image\n\n\npipe = FluxImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"flux1-dev.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder/model.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder_2/*.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"ae.safetensors\"),\n        ModelConfig(model_id=\"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta\", origin_file_pattern=\"diffusion_pytorch_model.safetensors\"),\n    ],\n)\nstate_dict = load_state_dict(\"models/train/FLUX.1-dev-Controlnet-Inpainting-Beta_full/epoch-0.safetensors\")\npipe.controlnet.models[0].load_state_dict(state_dict)\n\nimage = pipe(\n    prompt=\"a cat sitting on a chair, wearing sunglasses\",\n    controlnet_inputs=[ControlNetInput(\n        image=Image.open(\"data/example_image_dataset/inpaint/image_1.jpg\"),\n        inpaint_mask=Image.open(\"data/example_image_dataset/inpaint/mask.jpg\"),\n        scale=0.9\n    )],\n    height=1024, width=1024,\n    seed=0, rand_device=\"cuda\",\n)\nimage.save(\"image_FLUX.1-dev-Controlnet-Inpainting-Beta_full.jpg\")\n"
  },
  {
    "path": "examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Union-alpha.py",
    "content": "import torch\nfrom diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig, ControlNetInput\nfrom diffsynth import load_state_dict\nfrom PIL import Image\n\n\npipe = FluxImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"flux1-dev.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder/model.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder_2/*.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"ae.safetensors\"),\n        ModelConfig(model_id=\"InstantX/FLUX.1-dev-Controlnet-Union-alpha\", origin_file_pattern=\"diffusion_pytorch_model.safetensors\"),\n    ],\n)\nstate_dict = load_state_dict(\"models/train/FLUX.1-dev-Controlnet-Union-alpha_full/epoch-0.safetensors\")\npipe.controlnet.models[0].load_state_dict(state_dict)\n\nimage = pipe(\n    prompt=\"a dog\",\n    controlnet_inputs=[ControlNetInput(\n        image=Image.open(\"data/example_image_dataset/canny/image_1.jpg\"),\n        scale=0.9,\n        processor_id=\"canny\",\n    )],\n    height=768, width=768,\n    seed=0, rand_device=\"cuda\",\n)\nimage.save(\"image_FLUX.1-dev-Controlnet-Union-alpha_full.jpg\")\n"
  },
  {
    "path": "examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Upscaler.py",
    "content": "import torch\nfrom diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig, ControlNetInput\nfrom diffsynth import load_state_dict\nfrom PIL import Image\n\n\npipe = FluxImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"flux1-dev.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder/model.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder_2/*.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"ae.safetensors\"),\n        ModelConfig(model_id=\"jasperai/Flux.1-dev-Controlnet-Upscaler\", origin_file_pattern=\"diffusion_pytorch_model.safetensors\"),\n    ],\n)\nstate_dict = load_state_dict(\"models/train/FLUX.1-dev-Controlnet-Upscaler_full/epoch-0.safetensors\")\npipe.controlnet.models[0].load_state_dict(state_dict)\n\nimage = pipe(\n    prompt=\"a dog\",\n    controlnet_inputs=[ControlNetInput(\n        image=Image.open(\"data/example_image_dataset/upscale/image_1.jpg\"),\n        scale=0.9\n    )],\n    height=768, width=768,\n    seed=0, rand_device=\"cuda\",\n)\nimage.save(\"image_FLUX.1-dev-Controlnet-Upscaler_full.jpg\")\n"
  },
  {
    "path": "examples/flux/model_training/validate_full/FLUX.1-dev-IP-Adapter.py",
    "content": "import torch\nfrom diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig\nfrom diffsynth import load_state_dict\nfrom PIL import Image\n\n\npipe = FluxImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"flux1-dev.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder/model.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder_2/*.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"ae.safetensors\"),\n        ModelConfig(model_id=\"InstantX/FLUX.1-dev-IP-Adapter\", origin_file_pattern=\"ip-adapter.bin\"),\n        ModelConfig(model_id=\"google/siglip-so400m-patch14-384\", origin_file_pattern=\"model.safetensors\"),\n    ],\n)\nstate_dict = load_state_dict(\"models/train/FLUX.1-dev-IP-Adapter_full/epoch-0.safetensors\")\npipe.ipadapter.load_state_dict(state_dict)\n\nimage = pipe(\n    prompt=\"a dog\",\n    ipadapter_images=Image.open(\"data/example_image_dataset/1.jpg\"),\n    height=768, width=768,\n    seed=0\n)\nimage.save(\"image_FLUX.1-dev-IP-Adapter_full.jpg\")\n"
  },
  {
    "path": "examples/flux/model_training/validate_full/FLUX.1-dev-InfiniteYou.py",
    "content": "import torch\nfrom diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig, ControlNetInput\nfrom diffsynth import load_state_dict\nfrom PIL import Image\n\n\npipe = FluxImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"flux1-dev.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder/model.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder_2/*.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"ae.safetensors\"),\n        ModelConfig(model_id=\"ByteDance/InfiniteYou\", origin_file_pattern=\"infu_flux_v1.0/aes_stage2/image_proj_model.bin\"),\n        ModelConfig(model_id=\"ByteDance/InfiniteYou\", origin_file_pattern=\"infu_flux_v1.0/aes_stage2/InfuseNetModel/*.safetensors\"),\n    ],\n)\nstate_dict = load_state_dict(\"models/train/FLUX.1-dev-InfiniteYou_full/epoch-0.safetensors\")\nstate_dict_projector = {i.replace(\"image_proj_model.\", \"\"): state_dict[i] for i in state_dict if i.startswith(\"image_proj_model.\")}\npipe.image_proj_model.load_state_dict(state_dict_projector)\nstate_dict_controlnet = {i.replace(\"controlnet.models.0.\", \"\"): state_dict[i] for i in state_dict if i.startswith(\"controlnet.models.0.\")}\npipe.controlnet.models[0].load_state_dict(state_dict_controlnet)\n\nimage = pipe(\n    prompt=\"a man with a red hat\",\n    controlnet_inputs=[ControlNetInput(\n        image=Image.open(\"data/example_image_dataset/infiniteyou/image_1.jpg\"),\n    )],\n    height=1024, width=1024,\n    seed=0, rand_device=\"cuda\",\n)\nimage.save(\"image_FLUX.1-dev-InfiniteYou_full.jpg\")\n"
  },
  {
    "path": "examples/flux/model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py",
    "content": "import torch\nfrom diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig\nfrom diffsynth import load_state_dict\n\n\npipe = FluxImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"flux1-dev.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder/model.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder_2/*.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"ae.safetensors\"),\n        ModelConfig(model_id=\"DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev\", origin_file_pattern=\"model.safetensors\"),\n    ],\n)\nstate_dict = load_state_dict(\"models/train/FLUX.1-dev-LoRA-Encoder_full/epoch-0.safetensors\")\npipe.lora_encoder.load_state_dict(state_dict)\n\nlora = ModelConfig(model_id=\"VoidOc/flux_animal_forest1\", origin_file_pattern=\"20.safetensors\")\npipe.load_lora(pipe.dit, lora) # Use `pipe.clear_lora()` to drop the loaded LoRA.\n\nimage = pipe(prompt=\"\", seed=0, lora_encoder_inputs=lora)\nimage.save(\"image_FLUX.1-dev-LoRA-Encoder_full.jpg\")\n"
  },
  {
    "path": "examples/flux/model_training/validate_full/FLUX.1-dev.py",
    "content": "import torch\nfrom diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig\nfrom diffsynth import load_state_dict\n\n\npipe = FluxImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"flux1-dev.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder/model.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder_2/*.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"ae.safetensors\"),\n    ],\n)\nstate_dict = load_state_dict(\"models/train/FLUX.1-dev_full/epoch-0.safetensors\")\npipe.dit.load_state_dict(state_dict)\n\nimage = pipe(prompt=\"a dog\", seed=0)\nimage.save(\"image_FLUX.1-dev_full.jpg\")\n"
  },
  {
    "path": "examples/flux/model_training/validate_full/Nexus-Gen.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig\nfrom diffsynth import load_state_dict\n\npipe = FluxImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"DiffSynth-Studio/Nexus-GenV2\", origin_file_pattern=\"model*.safetensors\"),\n        ModelConfig(model_id=\"DiffSynth-Studio/Nexus-GenV2\", origin_file_pattern=\"edit_decoder.bin\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder/model.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder_2/*.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"ae.safetensors\"),\n    ],\n)\nstate_dict = load_state_dict(\"models/train/FLUX.1-NexusGen-Edit_full/epoch-0.safetensors\")\npipe.dit.load_state_dict(state_dict)\n\nref_image = Image.open(\"data/example_image_dataset/nexus_gen/image_1.png\").convert(\"RGB\")\nprompt = \"Add a pair of sunglasses.\"\nimage = pipe(\n    prompt=prompt, negative_prompt=\"\",\n    seed=42, cfg_scale=2.0, num_inference_steps=50,\n    nexus_gen_reference_image=ref_image,\n    height=512, width=512,\n)\nimage.save(\"NexusGen-Edit_full.jpg\")\n"
  },
  {
    "path": "examples/flux/model_training/validate_full/Step1X-Edit.py",
    "content": "import torch\nfrom diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig\nfrom diffsynth import load_state_dict\nfrom PIL import Image\n\n\npipe = FluxImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen2.5-VL-7B-Instruct\", origin_file_pattern=\"model-*.safetensors\"),\n        ModelConfig(model_id=\"stepfun-ai/Step1X-Edit\", origin_file_pattern=\"step1x-edit-i1258.safetensors\"),\n        ModelConfig(model_id=\"stepfun-ai/Step1X-Edit\", origin_file_pattern=\"vae.safetensors\"),\n    ],\n)\nstate_dict = load_state_dict(\"models/train/Step1X-Edit_full/epoch-0.safetensors\")\npipe.dit.load_state_dict(state_dict)\n\nimage = pipe(\n    prompt=\"Make the dog turn its head around.\",\n    step1x_reference_image=Image.open(\"data/example_image_dataset/2.jpg\").resize((768, 768)),\n    height=768, width=768, cfg_scale=6,\n    seed=0\n)\nimage.save(\"image_Step1X-Edit_full.jpg\")\n"
  },
  {
    "path": "examples/flux/model_training/validate_lora/FLEX.2-preview.py",
    "content": "import torch\nfrom diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig\n\n\npipe = FluxImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"ostris/Flex.2-preview\", origin_file_pattern=\"Flex.2-preview.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder/model.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder_2/*.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"ae.safetensors\"),\n    ],\n)\npipe.load_lora(pipe.dit, \"models/train/FLEX.2-preview_lora/epoch-4.safetensors\", alpha=1)\n\nimage = pipe(prompt=\"dog,white and brown dog, sitting on wall, under pink flowers\", seed=0)\nimage.save(\"image_FLEX.2-preview_lora.jpg\")\n"
  },
  {
    "path": "examples/flux/model_training/validate_lora/FLUX.1-Kontext-dev.py",
    "content": "import torch\nfrom diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig\nfrom PIL import Image\n\n\npipe = FluxImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-Kontext-dev\", origin_file_pattern=\"flux1-kontext-dev.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder/model.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder_2/*.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"ae.safetensors\"),\n    ],\n)\npipe.load_lora(pipe.dit, \"models/train/FLUX.1-Kontext-dev_lora/epoch-4.safetensors\", alpha=1)\n\nimage = pipe(\n    prompt=\"Make the dog turn its head around.\",\n    kontext_images=Image.open(\"data/example_image_dataset/2.jpg\").resize((768, 768)),\n    height=768, width=768,\n    seed=0\n)\nimage.save(\"image_FLUX.1-Kontext-dev_lora.jpg\")\n"
  },
  {
    "path": "examples/flux/model_training/validate_lora/FLUX.1-Krea-dev.py",
    "content": "import torch\nfrom diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig\n\n\npipe = FluxImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-Krea-dev\", origin_file_pattern=\"flux1-krea-dev.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder/model.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder_2/*.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"ae.safetensors\"),\n    ],\n)\npipe.load_lora(pipe.dit, \"models/train/FLUX.1-Krea-dev_lora/epoch-4.safetensors\", alpha=1)\n\nimage = pipe(prompt=\"a dog\", seed=0)\nimage.save(\"image_FLUX.1-Krea-dev_lora.jpg\")\n"
  },
  {
    "path": "examples/flux/model_training/validate_lora/FLUX.1-dev-AttriCtrl.py",
    "content": "import torch\nfrom diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig\n\n\npipe = FluxImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"flux1-dev.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder/model.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder_2/*.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"ae.safetensors\"),\n        ModelConfig(model_id=\"DiffSynth-Studio/AttriCtrl-FLUX.1-Dev\", origin_file_pattern=\"models/brightness.safetensors\")\n    ],\n)\npipe.load_lora(pipe.dit, \"models/train/FLUX.1-dev-AttriCtrl_lora/epoch-3.safetensors\", alpha=1)\n\nimage = pipe(prompt=\"a cat\", seed=0, value_controller_inputs=0.1, rand_device=\"cuda\")\nimage.save(\"image_FLUX.1-dev-AttriCtrl_lora.jpg\")\n"
  },
  {
    "path": "examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Inpainting-Beta.py",
    "content": "import torch\nfrom diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig, ControlNetInput\nfrom PIL import Image\n\n\npipe = FluxImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"flux1-dev.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder/model.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder_2/*.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"ae.safetensors\"),\n        ModelConfig(model_id=\"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta\", origin_file_pattern=\"diffusion_pytorch_model.safetensors\"),\n    ],\n)\npipe.load_lora(pipe.dit, \"models/train/FLUX.1-dev-Controlnet-Inpainting-Beta_lora/epoch-4.safetensors\", alpha=1)\n\nimage = pipe(\n    prompt=\"a cat sitting on a chair, wearing sunglasses\",\n    controlnet_inputs=[ControlNetInput(\n        image=Image.open(\"data/example_image_dataset/inpaint/image_1.jpg\"),\n        inpaint_mask=Image.open(\"data/example_image_dataset/inpaint/mask.jpg\"),\n        scale=0.9\n    )],\n    height=1024, width=1024,\n    seed=0, rand_device=\"cuda\",\n)\nimage.save(\"image_FLUX.1-dev-Controlnet-Inpainting-Beta_lora.jpg\")\n"
  },
  {
    "path": "examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Union-alpha.py",
    "content": "import torch\nfrom diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig, ControlNetInput\nfrom PIL import Image\n\n\npipe = FluxImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"flux1-dev.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder/model.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder_2/*.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"ae.safetensors\"),\n        ModelConfig(model_id=\"InstantX/FLUX.1-dev-Controlnet-Union-alpha\", origin_file_pattern=\"diffusion_pytorch_model.safetensors\"),\n    ],\n)\npipe.load_lora(pipe.dit, \"models/train/FLUX.1-dev-Controlnet-Union-alpha_lora/epoch-4.safetensors\", alpha=1)\n\nimage = pipe(\n    prompt=\"a dog\",\n    controlnet_inputs=[ControlNetInput(\n        image=Image.open(\"data/example_image_dataset/canny/image_1.jpg\"),\n        scale=0.9,\n        processor_id=\"canny\",\n    )],\n    height=768, width=768,\n    seed=0, rand_device=\"cuda\",\n)\nimage.save(\"image_FLUX.1-dev-Controlnet-Union-alpha_lora.jpg\")\n"
  },
  {
    "path": "examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Upscaler.py",
    "content": "import torch\nfrom diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig, ControlNetInput\nfrom PIL import Image\n\n\npipe = FluxImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"flux1-dev.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder/model.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder_2/*.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"ae.safetensors\"),\n        ModelConfig(model_id=\"jasperai/Flux.1-dev-Controlnet-Upscaler\", origin_file_pattern=\"diffusion_pytorch_model.safetensors\"),\n    ],\n)\npipe.load_lora(pipe.dit, \"models/train/FLUX.1-dev-Controlnet-Upscaler_lora/epoch-4.safetensors\", alpha=1)\n\nimage = pipe(\n    prompt=\"a dog\",\n    controlnet_inputs=[ControlNetInput(\n        image=Image.open(\"data/example_image_dataset/upscale/image_1.jpg\"),\n        scale=0.9\n    )],\n    height=768, width=768,\n    seed=0, rand_device=\"cuda\",\n)\nimage.save(\"image_FLUX.1-dev-Controlnet-Upscaler_lora.jpg\")\n"
  },
  {
    "path": "examples/flux/model_training/validate_lora/FLUX.1-dev-EliGen.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig\n\npipe = FluxImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"flux1-dev.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder/model.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder_2/*.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"ae.safetensors\"),\n    ],\n)\n\npipe.load_lora(pipe.dit, \"models/train/FLUX.1-dev-EliGen_lora/epoch-4.safetensors\", alpha=1)\n\nentity_prompts = [\"A beautiful girl\", \"sign 'Entity Control'\", \"shorts\", \"shirt\"]\nglobal_prompt = \"A beautiful girl wearing shirt and shorts in the street,  holding a sign 'Entity Control'\"\nmasks = [Image.open(f\"data/example_image_dataset/eligen/{i}.png\").convert('RGB') for i in range(len(entity_prompts))]\n# generate image\nimage = pipe(\n    prompt=global_prompt,\n    cfg_scale=1.0,\n    num_inference_steps=50,\n    embedded_guidance=3.5,\n    seed=42,\n    height=1024,\n    width=1024,\n    eligen_entity_prompts=entity_prompts,\n    eligen_entity_masks=masks,\n)\nimage.save(\"EliGen_lora.png\")\n"
  },
  {
    "path": "examples/flux/model_training/validate_lora/FLUX.1-dev-IP-Adapter.py",
    "content": "import torch\nfrom diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig\nfrom PIL import Image\n\n\npipe = FluxImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"flux1-dev.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder/model.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder_2/*.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"ae.safetensors\"),\n        ModelConfig(model_id=\"InstantX/FLUX.1-dev-IP-Adapter\", origin_file_pattern=\"ip-adapter.bin\"),\n        ModelConfig(model_id=\"google/siglip-so400m-patch14-384\", origin_file_pattern=\"model.safetensors\"),\n    ],\n)\npipe.load_lora(pipe.dit, \"models/train/FLUX.1-dev-IP-Adapter_lora/epoch-4.safetensors\", alpha=1)\n\nimage = pipe(\n    prompt=\"dog,white and brown dog, sitting on wall, under pink flowers\",\n    ipadapter_images=Image.open(\"data/example_image_dataset/1.jpg\"),\n    height=768, width=768,\n    seed=0\n)\nimage.save(\"image_FLUX.1-dev-IP-Adapter_lora.jpg\")\n"
  },
  {
    "path": "examples/flux/model_training/validate_lora/FLUX.1-dev-InfiniteYou.py",
    "content": "import torch\nfrom diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig, ControlNetInput\nfrom PIL import Image\n\n\npipe = FluxImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"flux1-dev.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder/model.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder_2/*.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"ae.safetensors\"),\n        ModelConfig(model_id=\"ByteDance/InfiniteYou\", origin_file_pattern=\"infu_flux_v1.0/aes_stage2/image_proj_model.bin\"),\n        ModelConfig(model_id=\"ByteDance/InfiniteYou\", origin_file_pattern=\"infu_flux_v1.0/aes_stage2/InfuseNetModel/*.safetensors\"),\n    ],\n)\npipe.load_lora(pipe.dit, \"models/train/FLUX.1-dev-InfiniteYou_lora/epoch-4.safetensors\", alpha=1)\n\nimage = pipe(\n    prompt=\"a man with a red hat\",\n    controlnet_inputs=[ControlNetInput(\n        image=Image.open(\"data/example_image_dataset/infiniteyou/image_1.jpg\"),\n    )],\n    height=1024, width=1024,\n    seed=0, rand_device=\"cuda\",\n)\nimage.save(\"image_FLUX.1-dev-InfiniteYou_lora.jpg\")\n"
  },
  {
    "path": "examples/flux/model_training/validate_lora/FLUX.1-dev.py",
    "content": "import torch\nfrom diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig\n\n\npipe = FluxImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"flux1-dev.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder/model.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder_2/*.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"ae.safetensors\"),\n    ],\n)\npipe.load_lora(pipe.dit, \"models/train/FLUX.1-dev_lora/epoch-4.safetensors\", alpha=1)\n\nimage = pipe(prompt=\"a dog\", seed=0)\nimage.save(\"image_FLUX.1-dev_lora.jpg\")\n"
  },
  {
    "path": "examples/flux/model_training/validate_lora/Nexus-Gen.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig\n\npipe = FluxImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"DiffSynth-Studio/Nexus-GenV2\", origin_file_pattern=\"model*.safetensors\"),\n        ModelConfig(model_id=\"DiffSynth-Studio/Nexus-GenV2\", origin_file_pattern=\"edit_decoder.bin\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder/model.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"text_encoder_2/*.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.1-dev\", origin_file_pattern=\"ae.safetensors\"),\n    ],\n)\npipe.load_lora(pipe.dit, \"models/train/FLUX.1-NexusGen-Edit_lora/epoch-4.safetensors\", alpha=1)\n\nref_image = Image.open(\"data/example_image_dataset/nexus_gen/image_1.png\").convert(\"RGB\")\nprompt = \"Add a pair of sunglasses.\"\nimage = pipe(\n    prompt=prompt, negative_prompt=\"\",\n    seed=42, cfg_scale=1.0, num_inference_steps=50,\n    nexus_gen_reference_image=ref_image,\n    height=512, width=512,\n)\nimage.save(\"NexusGen-Edit_lora.jpg\")\n"
  },
  {
    "path": "examples/flux/model_training/validate_lora/Step1X-Edit.py",
    "content": "import torch\nfrom diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig\nfrom PIL import Image\n\n\npipe = FluxImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen2.5-VL-7B-Instruct\", origin_file_pattern=\"model-*.safetensors\"),\n        ModelConfig(model_id=\"stepfun-ai/Step1X-Edit\", origin_file_pattern=\"step1x-edit-i1258.safetensors\"),\n        ModelConfig(model_id=\"stepfun-ai/Step1X-Edit\", origin_file_pattern=\"vae.safetensors\"),\n    ],\n)\npipe.load_lora(pipe.dit, \"models/train/Step1X-Edit_lora/epoch-4.safetensors\", alpha=1)\n\nimage = pipe(\n    prompt=\"Make the dog turn its head around.\",\n    step1x_reference_image=Image.open(\"data/example_image_dataset/2.jpg\").resize((768, 768)),\n    height=768, width=768, cfg_scale=6,\n    seed=0\n)\nimage.save(\"image_Step1X-Edit_lora.jpg\")\n"
  },
  {
    "path": "examples/flux2/README.md",
    "content": "English Document: https://diffsynth-studio-doc.readthedocs.io/en/latest/Model_Details/FLUX2.html\n\n中文文档：https://diffsynth-studio-doc.readthedocs.io/zh-cn/latest/Model_Details/FLUX2.html\n"
  },
  {
    "path": "examples/flux2/model_inference/FLUX.2-dev.py",
    "content": "from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig\nimport torch\n\n\nvram_config = {\n    \"offload_dtype\": torch.bfloat16,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cuda\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = Flux2ImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-dev\", origin_file_pattern=\"text_encoder/*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-dev\", origin_file_pattern=\"transformer/*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-dev\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"black-forest-labs/FLUX.2-dev\", origin_file_pattern=\"tokenizer/\"),\n)\nprompt = \"Realistic macro photograph of a hermit crab using a soda can as its shell, partially emerging from the can, captured with sharp detail and natural colors, on a sunlit beach with soft shadows and a shallow depth of field, with blurred ocean waves in the background. The can has the text `BFL Diffusers` on it and it has a color gradient that start with #FF5733 at the top and transitions to #33FF57 at the bottom.\"\nimage = pipe(prompt, seed=42, rand_device=\"cuda\", num_inference_steps=50)\nimage.save(\"image_FLUX.2-dev.jpg\")\n"
  },
  {
    "path": "examples/flux2/model_inference/FLUX.2-klein-4B.py",
    "content": "from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig\nimport torch\n\n\npipe = Flux2ImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-4B\", origin_file_pattern=\"text_encoder/*.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-4B\", origin_file_pattern=\"transformer/*.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-4B\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-4B\", origin_file_pattern=\"tokenizer/\"),\n)\nprompt = \"Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles.\"\nimage = pipe(prompt, seed=0, rand_device=\"cuda\", num_inference_steps=4)\nimage.save(\"image_FLUX.2-klein-4B.jpg\")\n\nprompt = \"change the color of the clothes to red\"\nimage = pipe(prompt, edit_image=[image], seed=1, rand_device=\"cuda\", num_inference_steps=4)\nimage.save(\"image_edit_FLUX.2-klein-4B.jpg\")\n"
  },
  {
    "path": "examples/flux2/model_inference/FLUX.2-klein-9B.py",
    "content": "from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig\nimport torch\n\n\npipe = Flux2ImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-9B\", origin_file_pattern=\"text_encoder/*.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-9B\", origin_file_pattern=\"transformer/*.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-9B\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-9B\", origin_file_pattern=\"tokenizer/\"),\n)\nprompt = \"Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles.\"\nimage = pipe(prompt, seed=0, rand_device=\"cuda\", num_inference_steps=4)\nimage.save(\"image_FLUX.2-klein-9B.jpg\")\n\nprompt = \"change the color of the clothes to red\"\nimage = pipe(prompt, edit_image=[image], seed=1, rand_device=\"cuda\", num_inference_steps=4)\nimage.save(\"image_edit_FLUX.2-klein-9B.jpg\")\n"
  },
  {
    "path": "examples/flux2/model_inference/FLUX.2-klein-base-4B.py",
    "content": "from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig\nimport torch\n\n\npipe = Flux2ImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-4B\", origin_file_pattern=\"text_encoder/*.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-base-4B\", origin_file_pattern=\"transformer/*.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-4B\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-4B\", origin_file_pattern=\"tokenizer/\"),\n)\nprompt = \"Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles.\"\nimage = pipe(prompt, seed=0, rand_device=\"cuda\", num_inference_steps=50, cfg_scale=4)\nimage.save(\"image_FLUX.2-klein-base-4B.jpg\")\n\nprompt = \"change the color of the clothes to red\"\nimage = pipe(prompt, edit_image=[image], seed=1, rand_device=\"cuda\", num_inference_steps=50, cfg_scale=4)\nimage.save(\"image_edit_FLUX.2-klein-base-4B.jpg\")\n"
  },
  {
    "path": "examples/flux2/model_inference/FLUX.2-klein-base-9B.py",
    "content": "from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig\nimport torch\n\n\npipe = Flux2ImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-9B\", origin_file_pattern=\"text_encoder/*.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-base-9B\", origin_file_pattern=\"transformer/*.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-9B\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-9B\", origin_file_pattern=\"tokenizer/\"),\n)\nprompt = \"Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles.\"\nimage = pipe(prompt, seed=0, rand_device=\"cuda\", num_inference_steps=50, cfg_scale=4)\nimage.save(\"image_FLUX.2-klein-base-9B.jpg\")\n\nprompt = \"change the color of the clothes to red\"\nimage = pipe(prompt, edit_image=[image], seed=1, rand_device=\"cuda\", num_inference_steps=50, cfg_scale=4)\nimage.save(\"image_edit_FLUX.2-klein-base-9B.jpg\")\n"
  },
  {
    "path": "examples/flux2/model_inference_low_vram/FLUX.2-dev.py",
    "content": "from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig\nimport torch\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.float8_e4m3fn,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e4m3fn,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = Flux2ImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-dev\", origin_file_pattern=\"text_encoder/*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-dev\", origin_file_pattern=\"transformer/*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-dev\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"black-forest-labs/FLUX.2-dev\", origin_file_pattern=\"tokenizer/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\nprompt = \"High resolution. A dreamy underwater portrait of a serene young woman in a flowing blue dress. Her hair floats softly around her face, strands delicately suspended in the water. Clear, shimmering light filters through, casting gentle highlights, while tiny bubbles rise around her. Her expression is calm, her features finely detailed—creating a tranquil, ethereal scene.\"\nimage = pipe(prompt, seed=42, rand_device=\"cuda\", num_inference_steps=50)\nimage.save(\"image_FLUX.2-dev.jpg\")\n"
  },
  {
    "path": "examples/flux2/model_inference_low_vram/FLUX.2-klein-4B.py",
    "content": "from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig\nimport torch\n\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.float8_e4m3fn,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e4m3fn,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = Flux2ImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-4B\", origin_file_pattern=\"text_encoder/*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-4B\", origin_file_pattern=\"transformer/*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-4B\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-4B\", origin_file_pattern=\"tokenizer/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\nprompt = \"Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles.\"\nimage = pipe(prompt, seed=0, rand_device=\"cuda\", num_inference_steps=4)\nimage.save(\"image_FLUX.2-klein-4B.jpg\")\n\nprompt = \"change the color of the clothes to red\"\nimage = pipe(prompt, edit_image=[image], seed=1, rand_device=\"cuda\", num_inference_steps=4)\nimage.save(\"image_edit_FLUX.2-klein-4B.jpg\")\n"
  },
  {
    "path": "examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py",
    "content": "from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig\nimport torch\n\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.float8_e4m3fn,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e4m3fn,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = Flux2ImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-9B\", origin_file_pattern=\"text_encoder/*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-9B\", origin_file_pattern=\"transformer/*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-9B\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-9B\", origin_file_pattern=\"tokenizer/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\nprompt = \"Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles.\"\nimage = pipe(prompt, seed=0, rand_device=\"cuda\", num_inference_steps=4)\nimage.save(\"image_FLUX.2-klein-9B.jpg\")\n\nprompt = \"change the color of the clothes to red\"\nimage = pipe(prompt, edit_image=[image], seed=1, rand_device=\"cuda\", num_inference_steps=4)\nimage.save(\"image_edit_FLUX.2-klein-9B.jpg\")\n"
  },
  {
    "path": "examples/flux2/model_inference_low_vram/FLUX.2-klein-base-4B.py",
    "content": "from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig\nimport torch\n\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.float8_e4m3fn,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e4m3fn,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = Flux2ImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-4B\", origin_file_pattern=\"text_encoder/*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-base-4B\", origin_file_pattern=\"transformer/*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-4B\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-4B\", origin_file_pattern=\"tokenizer/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\nprompt = \"Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles.\"\nimage = pipe(prompt, seed=0, rand_device=\"cuda\", num_inference_steps=50, cfg_scale=4)\nimage.save(\"image_FLUX.2-klein-base-4B.jpg\")\n\nprompt = \"change the color of the clothes to red\"\nimage = pipe(prompt, edit_image=[image], seed=1, rand_device=\"cuda\", num_inference_steps=50, cfg_scale=4)\nimage.save(\"image_edit_FLUX.2-klein-base-4B.jpg\")\n"
  },
  {
    "path": "examples/flux2/model_inference_low_vram/FLUX.2-klein-base-9B.py",
    "content": "from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig\nimport torch\n\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.float8_e4m3fn,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e4m3fn,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = Flux2ImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-9B\", origin_file_pattern=\"text_encoder/*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-base-9B\", origin_file_pattern=\"transformer/*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-9B\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-9B\", origin_file_pattern=\"tokenizer/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\nprompt = \"Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles.\"\nimage = pipe(prompt, seed=0, rand_device=\"cuda\", num_inference_steps=50, cfg_scale=4)\nimage.save(\"image_FLUX.2-klein-base-9B.jpg\")\n\nprompt = \"change the color of the clothes to red\"\nimage = pipe(prompt, edit_image=[image], seed=1, rand_device=\"cuda\", num_inference_steps=50, cfg_scale=4)\nimage.save(\"image_edit_FLUX.2-klein-base-9B.jpg\")\n"
  },
  {
    "path": "examples/flux2/model_training/full/FLUX.2-klein-4B.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"flux2/FLUX.2-klein-4B/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/flux2/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/flux2/FLUX.2-klein-4B \\\n  --dataset_metadata_path data/diffsynth_example_dataset/flux2/FLUX.2-klein-4B/metadata.csv \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 50 \\\n  --model_id_with_origin_paths \"black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors\" \\\n  --tokenizer_path \"black-forest-labs/FLUX.2-klein-4B:tokenizer/\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 2 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/FLUX.2-klein-4B_full\" \\\n  --trainable_models \"dit\" \\\n  --use_gradient_checkpointing\n\n# Edit\n\n# modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"qwen_image/Qwen-Image-Edit-2511/*\" --local_dir ./data/diffsynth_example_dataset\n\n# accelerate launch examples/flux2/model_training/train.py \\\n#   --dataset_base_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-Edit-2511 \\\n#   --dataset_metadata_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-Edit-2511/metadata.json \\\n#   --data_file_keys \"image,edit_image\" \\\n#   --extra_inputs \"edit_image\" \\\n#   --max_pixels 1048576 \\\n#   --dataset_repeat 50 \\\n#   --model_id_with_origin_paths \"black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors\" \\\n#   --tokenizer_path \"black-forest-labs/FLUX.2-klein-4B:tokenizer/\" \\\n#   --learning_rate 1e-5 \\\n#   --num_epochs 2 \\\n#   --remove_prefix_in_ckpt \"pipe.dit.\" \\\n#   --output_path 
\"./models/train/FLUX.2-klein-4B_full\" \\\n#   --trainable_models \"dit\" \\\n#   --use_gradient_checkpointing\n"
  },
  {
    "path": "examples/flux2/model_training/full/FLUX.2-klein-9B.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"flux2/FLUX.2-klein-9B/*\" --local_dir ./data/diffsynth_example_dataset\n\n# This script is tested on 8*A100\naccelerate launch --config_file examples/flux2/model_training/full/accelerate_config.yaml examples/flux2/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/flux2/FLUX.2-klein-9B \\\n  --dataset_metadata_path data/diffsynth_example_dataset/flux2/FLUX.2-klein-9B/metadata.csv \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 50 \\\n  --model_id_with_origin_paths \"black-forest-labs/FLUX.2-klein-9B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-9B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-9B:vae/diffusion_pytorch_model.safetensors\" \\\n  --tokenizer_path \"black-forest-labs/FLUX.2-klein-9B:tokenizer/\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 2 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/FLUX.2-klein-9B_full\" \\\n  --trainable_models \"dit\" \\\n  --use_gradient_checkpointing\n\n# Edit\n\n# modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"qwen_image/Qwen-Image-Edit-2511/*\" --local_dir ./data/diffsynth_example_dataset\n\n# accelerate launch --config_file examples/flux2/model_training/full/accelerate_config.yaml examples/flux2/model_training/train.py \\\n#   --dataset_base_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-Edit-2511 \\\n#   --dataset_metadata_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-Edit-2511/metadata.json \\\n#   --data_file_keys \"image,edit_image\" \\\n#   --extra_inputs \"edit_image\" \\\n#   --max_pixels 1048576 \\\n#   --dataset_repeat 50 \\\n#   --model_id_with_origin_paths \"black-forest-labs/FLUX.2-klein-9B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-9B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-9B:vae/diffusion_pytorch_model.safetensors\" \\\n#   
--tokenizer_path \"black-forest-labs/FLUX.2-klein-9B:tokenizer/\" \\\n#   --learning_rate 1e-5 \\\n#   --num_epochs 2 \\\n#   --remove_prefix_in_ckpt \"pipe.dit.\" \\\n#   --output_path \"./models/train/FLUX.2-klein-9B_full\" \\\n#   --trainable_models \"dit\" \\\n#   --use_gradient_checkpointing\n"
  },
  {
    "path": "examples/flux2/model_training/full/FLUX.2-klein-base-4B.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"flux2/FLUX.2-klein-base-4B/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/flux2/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/flux2/FLUX.2-klein-base-4B \\\n  --dataset_metadata_path data/diffsynth_example_dataset/flux2/FLUX.2-klein-base-4B/metadata.csv \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 50 \\\n  --model_id_with_origin_paths \"black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors\" \\\n  --tokenizer_path \"black-forest-labs/FLUX.2-klein-4B:tokenizer/\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 2 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/FLUX.2-klein-base-4B_full\" \\\n  --trainable_models \"dit\" \\\n  --use_gradient_checkpointing\n\n# Edit\n\n# modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"qwen_image/Qwen-Image-Edit-2511/*\" --local_dir ./data/diffsynth_example_dataset\n\n# accelerate launch examples/flux2/model_training/train.py \\\n#   --dataset_base_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-Edit-2511 \\\n#   --dataset_metadata_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-Edit-2511/metadata.json \\\n#   --data_file_keys \"image,edit_image\" \\\n#   --extra_inputs \"edit_image\" \\\n#   --max_pixels 1048576 \\\n#   --dataset_repeat 50 \\\n#   --model_id_with_origin_paths \"black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors\" \\\n#   --tokenizer_path \"black-forest-labs/FLUX.2-klein-4B:tokenizer/\" \\\n#   --learning_rate 1e-5 \\\n#   --num_epochs 2 \\\n#   --remove_prefix_in_ckpt 
\"pipe.dit.\" \\\n#   --output_path \"./models/train/FLUX.2-klein-base-4B_full\" \\\n#   --trainable_models \"dit\" \\\n#   --use_gradient_checkpointing\n"
  },
  {
    "path": "examples/flux2/model_training/full/FLUX.2-klein-base-9B.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"flux2/FLUX.2-klein-base-9B/*\" --local_dir ./data/diffsynth_example_dataset\n\n# This script is tested on 8*A100\naccelerate launch --config_file examples/flux2/model_training/full/accelerate_config.yaml examples/flux2/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/flux2/FLUX.2-klein-base-9B \\\n  --dataset_metadata_path data/diffsynth_example_dataset/flux2/FLUX.2-klein-base-9B/metadata.csv \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 50 \\\n  --model_id_with_origin_paths \"black-forest-labs/FLUX.2-klein-9B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-9B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-9B:vae/diffusion_pytorch_model.safetensors\" \\\n  --tokenizer_path \"black-forest-labs/FLUX.2-klein-9B:tokenizer/\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 2 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/FLUX.2-klein-base-9B_full\" \\\n  --trainable_models \"dit\" \\\n  --use_gradient_checkpointing\n\n# Edit\n\n# modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"qwen_image/Qwen-Image-Edit-2511/*\" --local_dir ./data/diffsynth_example_dataset\n\n# accelerate launch --config_file examples/flux2/model_training/full/accelerate_config.yaml examples/flux2/model_training/train.py \\\n#   --dataset_base_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-Edit-2511 \\\n#   --dataset_metadata_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-Edit-2511/metadata.json \\\n#   --data_file_keys \"image,edit_image\" \\\n#   --extra_inputs \"edit_image\" \\\n#   --max_pixels 1048576 \\\n#   --dataset_repeat 50 \\\n#   --model_id_with_origin_paths 
\"black-forest-labs/FLUX.2-klein-9B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-9B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-9B:vae/diffusion_pytorch_model.safetensors\" \\\n#   --tokenizer_path \"black-forest-labs/FLUX.2-klein-9B:tokenizer/\" \\\n#   --learning_rate 1e-5 \\\n#   --num_epochs 2 \\\n#   --remove_prefix_in_ckpt \"pipe.dit.\" \\\n#   --output_path \"./models/train/FLUX.2-klein-base-9B_full\" \\\n#   --trainable_models \"dit\" \\\n#   --use_gradient_checkpointing\n"
  },
  {
    "path": "examples/flux2/model_training/full/accelerate_config.yaml",
    "content": "compute_environment: LOCAL_MACHINE\ndebug: false\ndeepspeed_config:\n  gradient_accumulation_steps: 1\n  offload_optimizer_device: none\n  offload_param_device: none\n  zero3_init_flag: false\n  zero_stage: 2\ndistributed_type: DEEPSPEED\ndowncast_bf16: 'no'\nenable_cpu_affinity: false\nmachine_rank: 0\nmain_training_function: main\nmixed_precision: bf16\nnum_machines: 1\nnum_processes: 8\nrdzv_backend: static\nsame_network: true\ntpu_env: []\ntpu_use_cluster: false\ntpu_use_sudo: false\nuse_cpu: false\n"
  },
  {
    "path": "examples/flux2/model_training/full/accelerate_config_zero3.yaml",
    "content": "compute_environment: LOCAL_MACHINE\ndebug: false\ndeepspeed_config:\n  gradient_accumulation_steps: 1\n  offload_optimizer_device: none\n  offload_param_device: none\n  zero3_init_flag: true\n  zero3_save_16bit_model: true\n  zero_stage: 3\ndistributed_type: DEEPSPEED\ndowncast_bf16: 'no'\nenable_cpu_affinity: false\nmachine_rank: 0\nmain_training_function: main\nmixed_precision: bf16\nnum_machines: 1\nnum_processes: 8\nrdzv_backend: static\nsame_network: true\ntpu_env: []\ntpu_use_cluster: false\ntpu_use_sudo: false\nuse_cpu: false\n"
  },
  {
    "path": "examples/flux2/model_training/lora/FLUX.2-dev.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"flux2/FLUX.2-dev/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/flux2/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/flux2/FLUX.2-dev \\\n  --dataset_metadata_path data/diffsynth_example_dataset/flux2/FLUX.2-dev/metadata.csv \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 1 \\\n  --model_id_with_origin_paths \"black-forest-labs/FLUX.2-dev:text_encoder/*.safetensors,black-forest-labs/FLUX.2-dev:vae/diffusion_pytorch_model.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/FLUX.2-dev-LoRA-splited-cache\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_qkv_mlp_proj,to_out.0,to_add_out,linear_in,linear_out,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out,single_transformer_blocks.20.attn.to_out,single_transformer_blocks.21.attn.to_out,single_transformer_blocks.22.attn.to_out,single_transformer_blocks.23.attn.to_out,single_transformer_blocks.24.attn.to_out,single_transformer_blocks.25.attn.to_out,single_t
ransformer_blocks.26.attn.to_out,single_transformer_blocks.27.attn.to_out,single_transformer_blocks.28.attn.to_out,single_transformer_blocks.29.attn.to_out,single_transformer_blocks.30.attn.to_out,single_transformer_blocks.31.attn.to_out,single_transformer_blocks.32.attn.to_out,single_transformer_blocks.33.attn.to_out,single_transformer_blocks.34.attn.to_out,single_transformer_blocks.35.attn.to_out,single_transformer_blocks.36.attn.to_out,single_transformer_blocks.37.attn.to_out,single_transformer_blocks.38.attn.to_out,single_transformer_blocks.39.attn.to_out,single_transformer_blocks.40.attn.to_out,single_transformer_blocks.41.attn.to_out,single_transformer_blocks.42.attn.to_out,single_transformer_blocks.43.attn.to_out,single_transformer_blocks.44.attn.to_out,single_transformer_blocks.45.attn.to_out,single_transformer_blocks.46.attn.to_out,single_transformer_blocks.47.attn.to_out\" \\\n  --lora_rank 32 \\\n  --use_gradient_checkpointing \\\n  --dataset_num_workers 8 \\\n  --task \"sft:data_process\"\n\naccelerate launch examples/flux2/model_training/train.py \\\n  --dataset_base_path \"./models/train/FLUX.2-dev-LoRA-splited-cache\" \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 50 \\\n  --model_id_with_origin_paths \"black-forest-labs/FLUX.2-dev:transformer/*.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/FLUX.2-dev-LoRA-splited\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules 
\"to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_qkv_mlp_proj,to_out.0,to_add_out,linear_in,linear_out,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out,single_transformer_blocks.20.attn.to_out,single_transformer_blocks.21.attn.to_out,single_transformer_blocks.22.attn.to_out,single_transformer_blocks.23.attn.to_out,single_transformer_blocks.24.attn.to_out,single_transformer_blocks.25.attn.to_out,single_transformer_blocks.26.attn.to_out,single_transformer_blocks.27.attn.to_out,single_transformer_blocks.28.attn.to_out,single_transformer_blocks.29.attn.to_out,single_transformer_blocks.30.attn.to_out,single_transformer_blocks.31.attn.to_out,single_transformer_blocks.32.attn.to_out,single_transformer_blocks.33.attn.to_out,single_transformer_blocks.34.attn.to_out,single_transformer_blocks.35.attn.to_out,single_transformer_blocks.36.attn.to_out,single_transformer_blocks.37.attn.to_out,single_transformer_blocks.38.attn.to_out,single_transformer_blocks.39.attn.to_out,single_transformer_blocks.40.attn.to_out,single_transformer_blocks.41.attn.to_out,single_transformer_blocks.42.attn.to_out,single_transformer_blocks.43.attn.to_out,single_transformer_blocks.44.attn.to_out,single_transformer_blocks.45.attn.to_out,single_transforme
r_blocks.46.attn.to_out,single_transformer_blocks.47.attn.to_out\" \\\n  --lora_rank 32 \\\n  --use_gradient_checkpointing \\\n  --dataset_num_workers 8 \\\n  --task \"sft:train\"\n"
  },
  {
    "path": "examples/flux2/model_training/lora/FLUX.2-klein-4B.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"flux2/FLUX.2-klein-4B/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/flux2/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/flux2/FLUX.2-klein-4B \\\n  --dataset_metadata_path data/diffsynth_example_dataset/flux2/FLUX.2-klein-4B/metadata.csv \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 50 \\\n  --model_id_with_origin_paths \"black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors\" \\\n  --tokenizer_path \"black-forest-labs/FLUX.2-klein-4B:tokenizer/\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/FLUX.2-klein-4B_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"to_q,to_k,to_v,to_out.0,add_q_proj,add_k_proj,add_v_proj,to_add_out,linear_in,linear_out,to_qkv_mlp_proj,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out\" \\\n  --lora_rank 32 \\\n  --use_gradient_checkpointing\n\n# Edit\n\n# modelscope download --dataset 
DiffSynth-Studio/diffsynth_example_dataset --include \"qwen_image/Qwen-Image-Edit-2511/*\" --local_dir ./data/diffsynth_example_dataset\n\n# accelerate launch examples/flux2/model_training/train.py \\\n#   --dataset_base_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-Edit-2511 \\\n#   --dataset_metadata_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-Edit-2511/metadata.json \\\n#   --data_file_keys \"image,edit_image\" \\\n#   --extra_inputs \"edit_image\" \\\n#   --max_pixels 1048576 \\\n#   --dataset_repeat 50 \\\n#   --model_id_with_origin_paths \"black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors\" \\\n#   --tokenizer_path \"black-forest-labs/FLUX.2-klein-4B:tokenizer/\" \\\n#   --learning_rate 1e-4 \\\n#   --num_epochs 5 \\\n#   --remove_prefix_in_ckpt \"pipe.dit.\" \\\n#   --output_path \"./models/train/FLUX.2-klein-4B_lora\" \\\n#   --lora_base_model \"dit\" \\\n#   --lora_target_modules \"to_q,to_k,to_v,to_out.0,add_q_proj,add_k_proj,add_v_proj,to_add_out,linear_in,linear_out,to_qkv_mlp_proj,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out\" \\\n#   
--lora_rank 32 \\\n#   --use_gradient_checkpointing\n"
  },
  {
    "path": "examples/flux2/model_training/lora/FLUX.2-klein-9B.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"flux2/FLUX.2-klein-9B/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/flux2/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/flux2/FLUX.2-klein-9B \\\n  --dataset_metadata_path data/diffsynth_example_dataset/flux2/FLUX.2-klein-9B/metadata.csv \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 50 \\\n  --model_id_with_origin_paths \"black-forest-labs/FLUX.2-klein-9B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-9B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-9B:vae/diffusion_pytorch_model.safetensors\" \\\n  --tokenizer_path \"black-forest-labs/FLUX.2-klein-9B:tokenizer/\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/FLUX.2-klein-9B_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"to_q,to_k,to_v,to_out.0,add_q_proj,add_k_proj,add_v_proj,to_add_out,linear_in,linear_out,to_qkv_mlp_proj,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out,single_transformer_blocks.20.attn.to_out,single_transformer_blocks.21.attn.to_out,single_transformer_bloc
ks.22.attn.to_out,single_transformer_blocks.23.attn.to_out\" \\\n  --lora_rank 32 \\\n  --use_gradient_checkpointing\n\n# Edit\n\n# modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"qwen_image/Qwen-Image-Edit-2511/*\" --local_dir ./data/diffsynth_example_dataset\n\n# accelerate launch examples/flux2/model_training/train.py \\\n#   --dataset_base_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-Edit-2511 \\\n#   --dataset_metadata_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-Edit-2511/metadata.json \\\n#   --data_file_keys \"image,edit_image\" \\\n#   --extra_inputs \"edit_image\" \\\n#   --max_pixels 1048576 \\\n#   --dataset_repeat 50 \\\n#   --model_id_with_origin_paths \"black-forest-labs/FLUX.2-klein-9B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-9B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-9B:vae/diffusion_pytorch_model.safetensors\" \\\n#   --tokenizer_path \"black-forest-labs/FLUX.2-klein-9B:tokenizer/\" \\\n#   --learning_rate 1e-4 \\\n#   --num_epochs 5 \\\n#   --remove_prefix_in_ckpt \"pipe.dit.\" \\\n#   --output_path \"./models/train/FLUX.2-klein-9B_lora\" \\\n#   --lora_base_model \"dit\" \\\n#   --lora_target_modules 
\"to_q,to_k,to_v,to_out.0,add_q_proj,add_k_proj,add_v_proj,to_add_out,linear_in,linear_out,to_qkv_mlp_proj,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out,single_transformer_blocks.20.attn.to_out,single_transformer_blocks.21.attn.to_out,single_transformer_blocks.22.attn.to_out,single_transformer_blocks.23.attn.to_out\" \\\n#   --lora_rank 32 \\\n#   --use_gradient_checkpointing\n"
  },
  {
    "path": "examples/flux2/model_training/lora/FLUX.2-klein-base-4B.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"flux2/FLUX.2-klein-base-4B/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/flux2/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/flux2/FLUX.2-klein-base-4B \\\n  --dataset_metadata_path data/diffsynth_example_dataset/flux2/FLUX.2-klein-base-4B/metadata.csv \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 50 \\\n  --model_id_with_origin_paths \"black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors\" \\\n  --tokenizer_path \"black-forest-labs/FLUX.2-klein-4B:tokenizer/\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/FLUX.2-klein-base-4B_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"to_q,to_k,to_v,to_out.0,add_q_proj,add_k_proj,add_v_proj,to_add_out,linear_in,linear_out,to_qkv_mlp_proj,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out\" \\\n  --lora_rank 32 \\\n  --use_gradient_checkpointing\n\n# Edit\n\n# 
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"qwen_image/Qwen-Image-Edit-2511/*\" --local_dir ./data/diffsynth_example_dataset\n\n# accelerate launch examples/flux2/model_training/train.py \\\n#   --dataset_base_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-Edit-2511 \\\n#   --dataset_metadata_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-Edit-2511/metadata.json \\\n#   --data_file_keys \"image,edit_image\" \\\n#   --extra_inputs \"edit_image\" \\\n#   --max_pixels 1048576 \\\n#   --dataset_repeat 50 \\\n#   --model_id_with_origin_paths \"black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors\" \\\n#   --tokenizer_path \"black-forest-labs/FLUX.2-klein-4B:tokenizer/\" \\\n#   --learning_rate 1e-4 \\\n#   --num_epochs 5 \\\n#   --remove_prefix_in_ckpt \"pipe.dit.\" \\\n#   --output_path \"./models/train/FLUX.2-klein-base-4B_lora\" \\\n#   --lora_base_model \"dit\" \\\n#   --lora_target_modules 
\"to_q,to_k,to_v,to_out.0,add_q_proj,add_k_proj,add_v_proj,to_add_out,linear_in,linear_out,to_qkv_mlp_proj,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out\" \\\n#   --lora_rank 32 \\\n#   --use_gradient_checkpointing\n"
  },
  {
    "path": "examples/flux2/model_training/lora/FLUX.2-klein-base-9B.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"flux2/FLUX.2-klein-base-9B/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/flux2/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/flux2/FLUX.2-klein-base-9B \\\n  --dataset_metadata_path data/diffsynth_example_dataset/flux2/FLUX.2-klein-base-9B/metadata.csv \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 50 \\\n  --model_id_with_origin_paths \"black-forest-labs/FLUX.2-klein-9B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-9B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-9B:vae/diffusion_pytorch_model.safetensors\" \\\n  --tokenizer_path \"black-forest-labs/FLUX.2-klein-9B:tokenizer/\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/FLUX.2-klein-base-9B_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules 
\"to_q,to_k,to_v,to_out.0,add_q_proj,add_k_proj,add_v_proj,to_add_out,linear_in,linear_out,to_qkv_mlp_proj,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out,single_transformer_blocks.20.attn.to_out,single_transformer_blocks.21.attn.to_out,single_transformer_blocks.22.attn.to_out,single_transformer_blocks.23.attn.to_out\" \\\n  --lora_rank 32 \\\n  --use_gradient_checkpointing\n\n# Edit\n\n# modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"qwen_image/Qwen-Image-Edit-2511/*\" --local_dir ./data/diffsynth_example_dataset\n\n# accelerate launch examples/flux2/model_training/train.py \\\n#   --dataset_base_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-Edit-2511 \\\n#   --dataset_metadata_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-Edit-2511/metadata.json \\\n#   --data_file_keys \"image,edit_image\" \\\n#   --extra_inputs \"edit_image\" \\\n#   --max_pixels 1048576 \\\n#   --dataset_repeat 50 \\\n#   --model_id_with_origin_paths \"black-forest-labs/FLUX.2-klein-9B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-9B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-9B:vae/diffusion_pytorch_model.safetensors\" \\\n#   --tokenizer_path 
\"black-forest-labs/FLUX.2-klein-9B:tokenizer/\" \\\n#   --learning_rate 1e-4 \\\n#   --num_epochs 5 \\\n#   --remove_prefix_in_ckpt \"pipe.dit.\" \\\n#   --output_path \"./models/train/FLUX.2-klein-base-9B_lora\" \\\n#   --lora_base_model \"dit\" \\\n#   --lora_target_modules \"to_q,to_k,to_v,to_out.0,add_q_proj,add_k_proj,add_v_proj,to_add_out,linear_in,linear_out,to_qkv_mlp_proj,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out,single_transformer_blocks.20.attn.to_out,single_transformer_blocks.21.attn.to_out,single_transformer_blocks.22.attn.to_out,single_transformer_blocks.23.attn.to_out\" \\\n#   --lora_rank 32 \\\n#   --use_gradient_checkpointing\n"
  },
  {
    "path": "examples/flux2/model_training/special/npu_training/FLUX.2-dev-Lora-NPU.sh",
    "content": "export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True\nexport CPU_AFFINITY_CONF=1\n\nmodelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"flux2/FLUX.2-dev/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/flux2/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/flux2/FLUX.2-dev \\\n  --dataset_metadata_path data/diffsynth_example_dataset/flux2/FLUX.2-dev/metadata.csv \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 1 \\\n  --model_id_with_origin_paths \"black-forest-labs/FLUX.2-dev:text_encoder/*.safetensors,black-forest-labs/FLUX.2-dev:vae/diffusion_pytorch_model.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/FLUX.2-dev-LoRA-splited-cache\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_qkv_mlp_proj,to_out.0,to_add_out,linear_in,linear_out,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out,single_transformer_blocks.20.attn.to_out,single_transformer_blocks.21.attn.to_out,single_transformer_blocks.22.attn.to_out,single_transformer_blocks.23.attn.to_out,sing
le_transformer_blocks.24.attn.to_out,single_transformer_blocks.25.attn.to_out,single_transformer_blocks.26.attn.to_out,single_transformer_blocks.27.attn.to_out,single_transformer_blocks.28.attn.to_out,single_transformer_blocks.29.attn.to_out,single_transformer_blocks.30.attn.to_out,single_transformer_blocks.31.attn.to_out,single_transformer_blocks.32.attn.to_out,single_transformer_blocks.33.attn.to_out,single_transformer_blocks.34.attn.to_out,single_transformer_blocks.35.attn.to_out,single_transformer_blocks.36.attn.to_out,single_transformer_blocks.37.attn.to_out,single_transformer_blocks.38.attn.to_out,single_transformer_blocks.39.attn.to_out,single_transformer_blocks.40.attn.to_out,single_transformer_blocks.41.attn.to_out,single_transformer_blocks.42.attn.to_out,single_transformer_blocks.43.attn.to_out,single_transformer_blocks.44.attn.to_out,single_transformer_blocks.45.attn.to_out,single_transformer_blocks.46.attn.to_out,single_transformer_blocks.47.attn.to_out\" \\\n  --lora_rank 32 \\\n  --use_gradient_checkpointing \\\n  --dataset_num_workers 8 \\\n  --task \"sft:data_process\"\n\naccelerate launch --config_file examples/flux2/model_training/full/accelerate_config_zero3.yaml examples/flux2/model_training/train.py \\\n  --dataset_base_path \"./models/train/FLUX.2-dev-LoRA-splited-cache\" \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 50 \\\n  --model_id_with_origin_paths \"black-forest-labs/FLUX.2-dev:transformer/*.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/FLUX.2-dev-LoRA-splited\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules 
\"to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_qkv_mlp_proj,to_out.0,to_add_out,linear_in,linear_out,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out,single_transformer_blocks.20.attn.to_out,single_transformer_blocks.21.attn.to_out,single_transformer_blocks.22.attn.to_out,single_transformer_blocks.23.attn.to_out,single_transformer_blocks.24.attn.to_out,single_transformer_blocks.25.attn.to_out,single_transformer_blocks.26.attn.to_out,single_transformer_blocks.27.attn.to_out,single_transformer_blocks.28.attn.to_out,single_transformer_blocks.29.attn.to_out,single_transformer_blocks.30.attn.to_out,single_transformer_blocks.31.attn.to_out,single_transformer_blocks.32.attn.to_out,single_transformer_blocks.33.attn.to_out,single_transformer_blocks.34.attn.to_out,single_transformer_blocks.35.attn.to_out,single_transformer_blocks.36.attn.to_out,single_transformer_blocks.37.attn.to_out,single_transformer_blocks.38.attn.to_out,single_transformer_blocks.39.attn.to_out,single_transformer_blocks.40.attn.to_out,single_transformer_blocks.41.attn.to_out,single_transformer_blocks.42.attn.to_out,single_transformer_blocks.43.attn.to_out,single_transformer_blocks.44.attn.to_out,single_transformer_blocks.45.attn.to_out,single_transforme
r_blocks.46.attn.to_out,single_transformer_blocks.47.attn.to_out\" \\\n  --lora_rank 32 \\\n  --use_gradient_checkpointing \\\n  --dataset_num_workers 8 \\\n  --initialize_model_on_cpu \\\n  --task \"sft:train\"\n"
  },
  {
    "path": "examples/flux2/model_training/special/npu_training/FLUX.2-klein-9B-NPU.sh",
    "content": "# This script is tested on 8*910B(NPU)\nexport PYTORCH_NPU_ALLOC_CONF=expandable_segments:True\nexport CPU_AFFINITY_CONF=1\n\nmodelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"flux2/FLUX.2-klein-9B/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch --config_file examples/flux2/model_training/full/accelerate_config.yaml examples/flux2/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/flux2/FLUX.2-klein-9B \\\n  --dataset_metadata_path data/diffsynth_example_dataset/flux2/FLUX.2-klein-9B/metadata.csv \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 50 \\\n  --model_id_with_origin_paths \"black-forest-labs/FLUX.2-klein-9B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-9B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-9B:vae/diffusion_pytorch_model.safetensors\" \\\n  --tokenizer_path \"black-forest-labs/FLUX.2-klein-9B:tokenizer/\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 2 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/FLUX.2-klein-9B_full\" \\\n  --trainable_models \"dit\" \\\n  --use_gradient_checkpointing\n\n# Edit\n\n# modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"qwen_image/Qwen-Image-Edit-2511/*\" --local_dir ./data/diffsynth_example_dataset\n\n# accelerate launch --config_file examples/flux2/model_training/full/accelerate_config.yaml examples/flux2/model_training/train.py \\\n#   --dataset_base_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-Edit-2511 \\\n#   --dataset_metadata_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-Edit-2511/metadata.json \\\n#   --data_file_keys \"image,edit_image\" \\\n#   --extra_inputs \"edit_image\" \\\n#   --max_pixels 1048576 \\\n#   --dataset_repeat 50 \\\n#   --model_id_with_origin_paths 
\"black-forest-labs/FLUX.2-klein-9B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-9B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-9B:vae/diffusion_pytorch_model.safetensors\" \\\n#   --tokenizer_path \"black-forest-labs/FLUX.2-klein-9B:tokenizer/\" \\\n#   --learning_rate 1e-5 \\\n#   --num_epochs 2 \\\n#   --remove_prefix_in_ckpt \"pipe.dit.\" \\\n#   --output_path \"./models/train/FLUX.2-klein-9B_full\" \\\n#   --trainable_models \"dit\" \\\n#   --use_gradient_checkpointing\n"
  },
  {
    "path": "examples/flux2/model_training/train.py",
    "content": "import torch, os, argparse, accelerate\nfrom diffsynth.core import UnifiedDataset\nfrom diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig\nfrom diffsynth.diffusion import *\nos.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n\n\nclass Flux2ImageTrainingModule(DiffusionTrainingModule):\n    def __init__(\n        self,\n        model_paths=None, model_id_with_origin_paths=None,\n        tokenizer_path=None,\n        trainable_models=None,\n        lora_base_model=None, lora_target_modules=\"\", lora_rank=32, lora_checkpoint=None,\n        preset_lora_path=None, preset_lora_model=None,\n        use_gradient_checkpointing=True,\n        use_gradient_checkpointing_offload=False,\n        extra_inputs=None,\n        fp8_models=None,\n        offload_models=None,\n        device=\"cpu\",\n        task=\"sft\",\n    ):\n        super().__init__()\n        # Load models\n        model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device)\n        tokenizer_config = self.parse_path_or_model_id(tokenizer_path, default_value=ModelConfig(model_id=\"black-forest-labs/FLUX.2-dev\", origin_file_pattern=\"tokenizer/\"))\n        self.pipe = Flux2ImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config)\n        self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model)\n\n        # Training mode\n        self.switch_pipe_to_training_mode(\n            self.pipe, trainable_models,\n            lora_base_model, lora_target_modules, lora_rank, lora_checkpoint,\n            preset_lora_path, preset_lora_model,\n            task=task,\n        )\n        \n        # Other configs\n        self.use_gradient_checkpointing = use_gradient_checkpointing\n        self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload\n        
self.extra_inputs = extra_inputs.split(\",\") if extra_inputs is not None else []\n        self.fp8_models = fp8_models\n        self.task = task\n        self.task_to_loss = {\n            \"sft:data_process\": lambda pipe, *args: args,\n            \"direct_distill:data_process\": lambda pipe, *args: args,\n            \"sft\": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),\n            \"sft:train\": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),\n            \"direct_distill\": lambda pipe, inputs_shared, inputs_posi, inputs_nega: DirectDistillLoss(pipe, **inputs_shared, **inputs_posi),\n            \"direct_distill:train\": lambda pipe, inputs_shared, inputs_posi, inputs_nega: DirectDistillLoss(pipe, **inputs_shared, **inputs_posi),\n        }\n        \n    def get_pipeline_inputs(self, data):\n        inputs_posi = {\"prompt\": data[\"prompt\"]}\n        inputs_nega = {\"negative_prompt\": \"\"}\n        inputs_shared = {\n            # Assume you are using this pipeline for inference,\n            # please fill in the input parameters.\n            \"input_image\": data[\"image\"],\n            \"height\": data[\"image\"].size[1],\n            \"width\": data[\"image\"].size[0],\n            # Please do not modify the following parameters\n            # unless you clearly know what this will cause.\n            \"embedded_guidance\": 1.0,\n            \"cfg_scale\": 1,\n            \"rand_device\": self.pipe.device,\n            \"use_gradient_checkpointing\": self.use_gradient_checkpointing,\n            \"use_gradient_checkpointing_offload\": self.use_gradient_checkpointing_offload,\n        }\n        inputs_shared = self.parse_extra_inputs(data, self.extra_inputs, inputs_shared)\n        return inputs_shared, inputs_posi, inputs_nega\n    \n    def forward(self, data, inputs=None):\n        if inputs is None: inputs = 
self.get_pipeline_inputs(data)\n        inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype)\n        for unit in self.pipe.units:\n            inputs = self.pipe.unit_runner(unit, self.pipe, *inputs)\n        loss = self.task_to_loss[self.task](self.pipe, *inputs)\n        return loss\n\n\ndef flux2_parser():\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser = add_general_config(parser)\n    parser = add_image_size_config(parser)\n    parser.add_argument(\"--tokenizer_path\", type=str, default=None, help=\"Path to tokenizer.\")\n    parser.add_argument(\"--initialize_model_on_cpu\", default=False, action=\"store_true\", help=\"Whether to initialize models on CPU.\")\n    return parser\n\n\nif __name__ == \"__main__\":\n    parser = flux2_parser()\n    args = parser.parse_args()\n    accelerator = accelerate.Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        kwargs_handlers=[accelerate.DistributedDataParallelKwargs(find_unused_parameters=args.find_unused_parameters)],\n    )\n    dataset = UnifiedDataset(\n        base_path=args.dataset_base_path,\n        metadata_path=args.dataset_metadata_path,\n        repeat=args.dataset_repeat,\n        data_file_keys=args.data_file_keys.split(\",\"),\n        main_data_operator=UnifiedDataset.default_image_operator(\n            base_path=args.dataset_base_path,\n            max_pixels=args.max_pixels,\n            height=args.height,\n            width=args.width,\n            height_division_factor=16,\n            width_division_factor=16,\n        )\n    )\n    model = Flux2ImageTrainingModule(\n        model_paths=args.model_paths,\n        model_id_with_origin_paths=args.model_id_with_origin_paths,\n        tokenizer_path=args.tokenizer_path,\n        trainable_models=args.trainable_models,\n        lora_base_model=args.lora_base_model,\n        
lora_target_modules=args.lora_target_modules,\n        lora_rank=args.lora_rank,\n        lora_checkpoint=args.lora_checkpoint,\n        preset_lora_path=args.preset_lora_path,\n        preset_lora_model=args.preset_lora_model,\n        use_gradient_checkpointing=args.use_gradient_checkpointing,\n        use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,\n        extra_inputs=args.extra_inputs,\n        fp8_models=args.fp8_models,\n        offload_models=args.offload_models,\n        task=args.task,\n        device=\"cpu\" if args.initialize_model_on_cpu else accelerator.device,\n    )\n    model_logger = ModelLogger(\n        args.output_path,\n        remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,\n    )\n    launcher_map = {\n        \"sft:data_process\": launch_data_process_task,\n        \"direct_distill:data_process\": launch_data_process_task,\n        \"sft\": launch_training_task,\n        \"sft:train\": launch_training_task,\n        \"direct_distill\": launch_training_task,\n        \"direct_distill:train\": launch_training_task,\n    }\n    launcher_map[args.task](accelerator, dataset, model, model_logger, args=args)\n"
  },
  {
    "path": "examples/flux2/model_training/validate_full/FLUX.2-klein-4B.py",
    "content": "from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig\nfrom diffsynth.core import load_state_dict\nimport torch\n\n\npipe = Flux2ImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-4B\", origin_file_pattern=\"text_encoder/*.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-4B\", origin_file_pattern=\"transformer/*.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-4B\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-4B\", origin_file_pattern=\"tokenizer/\"),\n)\nstate_dict = load_state_dict(\"./models/train/FLUX.2-klein-4B_full/epoch-1.safetensors\", torch_dtype=torch.bfloat16)\npipe.dit.load_state_dict(state_dict)\nprompt = \"a dog\"\nimage = pipe(prompt=prompt, seed=0, num_inference_steps=40, cfg_scale=4, height=768, width=768)\nimage.save(\"image.jpg\")\n"
  },
  {
    "path": "examples/flux2/model_training/validate_full/FLUX.2-klein-9B.py",
    "content": "from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig\nfrom diffsynth.core import load_state_dict\nimport torch\n\n\npipe = Flux2ImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-9B\", origin_file_pattern=\"text_encoder/*.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-9B\", origin_file_pattern=\"transformer/*.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-9B\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-9B\", origin_file_pattern=\"tokenizer/\"),\n)\nstate_dict = load_state_dict(\"./models/train/FLUX.2-klein-9B_full/epoch-1.safetensors\", torch_dtype=torch.bfloat16)\npipe.dit.load_state_dict(state_dict)\nprompt = \"a dog\"\nimage = pipe(prompt=prompt, seed=0, num_inference_steps=40, cfg_scale=4, height=768, width=768)\nimage.save(\"image.jpg\")\n"
  },
  {
    "path": "examples/flux2/model_training/validate_full/FLUX.2-klein-base-4B.py",
    "content": "from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig\nfrom diffsynth.core import load_state_dict\nimport torch\n\n\npipe = Flux2ImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-4B\", origin_file_pattern=\"text_encoder/*.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-base-4B\", origin_file_pattern=\"transformer/*.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-4B\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-4B\", origin_file_pattern=\"tokenizer/\"),\n)\nstate_dict = load_state_dict(\"./models/train/FLUX.2-klein-base-4B_full/epoch-1.safetensors\", torch_dtype=torch.bfloat16)\npipe.dit.load_state_dict(state_dict)\nprompt = \"a dog\"\nimage = pipe(prompt=prompt, seed=0, num_inference_steps=40, cfg_scale=4, height=768, width=768)\nimage.save(\"image.jpg\")\n"
  },
  {
    "path": "examples/flux2/model_training/validate_full/FLUX.2-klein-base-9B.py",
    "content": "from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig\nfrom diffsynth.core import load_state_dict\nimport torch\n\n\npipe = Flux2ImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-9B\", origin_file_pattern=\"text_encoder/*.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-base-9B\", origin_file_pattern=\"transformer/*.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-9B\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-9B\", origin_file_pattern=\"tokenizer/\"),\n)\nstate_dict = load_state_dict(\"./models/train/FLUX.2-klein-base-9B_full/epoch-1.safetensors\", torch_dtype=torch.bfloat16)\npipe.dit.load_state_dict(state_dict)\nprompt = \"a dog\"\nimage = pipe(prompt=prompt, seed=0, num_inference_steps=40, cfg_scale=4, height=768, width=768)\nimage.save(\"image.jpg\")\n"
  },
  {
    "path": "examples/flux2/model_training/validate_lora/FLUX.2-dev.py",
    "content": "from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig\nimport torch\n\n\nvram_config = {\n    \"offload_dtype\": torch.bfloat16,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cuda\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = Flux2ImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-dev\", origin_file_pattern=\"text_encoder/*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-dev\", origin_file_pattern=\"transformer/*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-dev\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"black-forest-labs/FLUX.2-dev\", origin_file_pattern=\"tokenizer/\"),\n)\npipe.load_lora(pipe.dit, \"./models/train/FLUX.2-dev-LoRA-splited/epoch-4.safetensors\")\nprompt = \"a dog\"\nimage = pipe(prompt, seed=0)\nimage.save(\"image_FLUX.2-dev_lora.jpg\")\n"
  },
  {
    "path": "examples/flux2/model_training/validate_lora/FLUX.2-klein-4B.py",
    "content": "from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig\nimport torch\n\n\npipe = Flux2ImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-4B\", origin_file_pattern=\"text_encoder/*.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-4B\", origin_file_pattern=\"transformer/*.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-4B\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-4B\", origin_file_pattern=\"tokenizer/\"),\n)\npipe.load_lora(pipe.dit, \"./models/train/FLUX.2-klein-4B_lora/epoch-4.safetensors\")\nprompt = \"a dog\"\nimage = pipe(prompt=prompt, seed=0, num_inference_steps=40, cfg_scale=4, height=768, width=768)\nimage.save(\"image.jpg\")\n"
  },
  {
    "path": "examples/flux2/model_training/validate_lora/FLUX.2-klein-9B.py",
    "content": "from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig\nimport torch\n\n\npipe = Flux2ImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-9B\", origin_file_pattern=\"text_encoder/*.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-9B\", origin_file_pattern=\"transformer/*.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-9B\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-9B\", origin_file_pattern=\"tokenizer/\"),\n)\npipe.load_lora(pipe.dit, \"./models/train/FLUX.2-klein-9B_lora/epoch-4.safetensors\")\nprompt = \"a dog\"\nimage = pipe(prompt=prompt, seed=0, num_inference_steps=40, cfg_scale=4, height=768, width=768)\nimage.save(\"image.jpg\")\n"
  },
  {
    "path": "examples/flux2/model_training/validate_lora/FLUX.2-klein-base-4B.py",
    "content": "from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig\nimport torch\n\n\npipe = Flux2ImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-4B\", origin_file_pattern=\"text_encoder/*.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-base-4B\", origin_file_pattern=\"transformer/*.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-4B\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-4B\", origin_file_pattern=\"tokenizer/\"),\n)\npipe.load_lora(pipe.dit, \"./models/train/FLUX.2-klein-base-4B_lora/epoch-4.safetensors\")\nprompt = \"a dog\"\nimage = pipe(prompt=prompt, seed=0, num_inference_steps=40, cfg_scale=4, height=768, width=768)\nimage.save(\"image.jpg\")\n"
  },
  {
    "path": "examples/flux2/model_training/validate_lora/FLUX.2-klein-base-9B.py",
    "content": "from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig\nimport torch\n\n\npipe = Flux2ImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-9B\", origin_file_pattern=\"text_encoder/*.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-base-9B\", origin_file_pattern=\"transformer/*.safetensors\"),\n        ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-9B\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-9B\", origin_file_pattern=\"tokenizer/\"),\n)\npipe.load_lora(pipe.dit, \"./models/train/FLUX.2-klein-base-9B_lora/epoch-4.safetensors\")\nprompt = \"a dog\"\nimage = pipe(prompt=prompt, seed=0, num_inference_steps=40, cfg_scale=4, height=768, width=768)\nimage.save(\"image.jpg\")\n"
  },
  {
    "path": "examples/ltx2/README.md",
    "content": "English Document: https://diffsynth-studio-doc.readthedocs.io/en/latest/Model_Details/LTX-2.html\n\n中文文档：https://diffsynth-studio-doc.readthedocs.io/zh-cn/latest/Model_Details/LTX-2.html\n"
  },
  {
    "path": "examples/ltx2/model_inference/LTX-2-I2AV-DistilledPipeline.py",
    "content": "import torch\nfrom diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig\nfrom diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2\nfrom PIL import Image\nfrom modelscope import dataset_snapshot_download\n\nvram_config = {\n    \"offload_dtype\": torch.bfloat16,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cuda\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = LTX2AudioVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\", origin_file_pattern=\"model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"transformer_distilled.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"text_encoder_post_modules.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"video_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"audio_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"audio_vocoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"video_vae_encoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-spatial-upscaler-x2-1.0.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\"),\n)\n\nprompt = \"A girl is very happy, she is speaking: “I enjoy 
working with Diffsynth-Studio, it's a perfect framework.”\"\nnegative_prompt = (\n    \"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, \"\n    \"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, \"\n    \"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, \"\n    \"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of \"\n    \"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent \"\n    \"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny \"\n    \"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, \"\n    \"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, \"\n    \"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward \"\n    \"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, \"\n    \"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts.\"\n)\nheight, width, num_frames = 512 * 2, 768 * 2, 121\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/examples_in_diffsynth\",\n    local_dir=\"./\",\n    allow_file_pattern=[\"data/examples/ltx-2/first_frame.jpg\"]\n)\nimage = Image.open(\"data/examples/ltx-2/first_frame.jpg\").convert(\"RGB\").resize((width, height))\n# first frame\nvideo, audio = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    seed=43,\n    height=height,\n    width=width,\n    num_frames=num_frames,\n    tiled=True,\n    use_distilled_pipeline=True,\n    input_images=[image],\n    input_images_indexes=[0],\n    input_images_strength=1.0,\n)\nwrite_video_audio_ltx2(\n    video=video,\n    audio=audio,\n    
output_path='ltx2_distilled_i2av_first.mp4',\n    fps=24,\n    audio_sample_rate=24000,\n)\n"
  },
  {
    "path": "examples/ltx2/model_inference/LTX-2-I2AV-OneStage.py",
    "content": "import torch\nfrom diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig\nfrom diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2\nfrom PIL import Image\nfrom modelscope import dataset_snapshot_download\n\nvram_config = {\n    \"offload_dtype\": torch.bfloat16,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cuda\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = LTX2AudioVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\", origin_file_pattern=\"model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"transformer.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"text_encoder_post_modules.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"video_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"audio_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"audio_vocoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"video_vae_encoder.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\"),\n)\nprompt = \"A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”\"\nnegative_prompt = \"blurry, out of focus, overexposed, underexposed, low contrast, washed 
out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts.\"\nheight, width, num_frames = 512, 768, 121\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/examples_in_diffsynth\",\n    local_dir=\"./\",\n    allow_file_pattern=[\"data/examples/ltx-2/first_frame.jpg\"]\n)\nimage = Image.open(\"data/examples/ltx-2/first_frame.jpg\").convert(\"RGB\").resize((width, height))\n# first frame\nvideo, audio = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    seed=43,\n    height=height,\n    width=width,\n    num_frames=num_frames,\n    tiled=False,\n    input_images=[image],\n    input_images_indexes=[0],\n    input_images_strength=1.0,\n    num_inference_steps=40,\n)\nwrite_video_audio_ltx2(\n    video=video,\n    audio=audio,\n    output_path='ltx2_onestage_i2av_first.mp4',\n    fps=24,\n    audio_sample_rate=24000,\n)\n"
  },
  {
    "path": "examples/ltx2/model_inference/LTX-2-I2AV-TwoStage.py",
    "content": "import torch\nfrom diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig\nfrom diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2\nfrom PIL import Image\nfrom modelscope import dataset_snapshot_download\n\nvram_config = {\n    \"offload_dtype\": torch.bfloat16,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cuda\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = LTX2AudioVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\", origin_file_pattern=\"model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"transformer.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"text_encoder_post_modules.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"video_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"audio_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"audio_vocoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"video_vae_encoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-spatial-upscaler-x2-1.0.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\"),\n    stage2_lora_config=ModelConfig(model_id=\"Lightricks/LTX-2\", 
origin_file_pattern=\"ltx-2-19b-distilled-lora-384.safetensors\"),\n)\n\nprompt = \"A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”\"\nnegative_prompt = (\n    \"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, \"\n    \"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, \"\n    \"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, \"\n    \"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of \"\n    \"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent \"\n    \"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny \"\n    \"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, \"\n    \"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, \"\n    \"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward \"\n    \"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, \"\n    \"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts.\"\n)\nheight, width, num_frames = 512 * 2, 768 * 2, 121\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/examples_in_diffsynth\",\n    local_dir=\"./\",\n    allow_file_pattern=[\"data/examples/ltx-2/first_frame.jpg\"]\n)\nimage = Image.open(\"data/examples/ltx-2/first_frame.jpg\").convert(\"RGB\").resize((width, height))\n# first frame\nvideo, audio = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    seed=42,\n    height=height,\n    width=width,\n    num_frames=num_frames,\n    tiled=True,\n    use_two_stage_pipeline=True,\n    num_inference_steps=40,\n    
input_images=[image],\n    input_images_indexes=[0],\n    input_images_strength=1.0,\n)\nwrite_video_audio_ltx2(\n    video=video,\n    audio=audio,\n    output_path='ltx2_twostage_i2av_first.mp4',\n    fps=24,\n    audio_sample_rate=24000,\n)\n"
  },
  {
    "path": "examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-In.py",
    "content": "import torch\nfrom diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig\nfrom diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2\n\nvram_config = {\n    \"offload_dtype\": torch.bfloat16,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cuda\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = LTX2AudioVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\", origin_file_pattern=\"model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"transformer.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"text_encoder_post_modules.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"video_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"audio_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"audio_vocoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"video_vae_encoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-spatial-upscaler-x2-1.0.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\"),\n    stage2_lora_config=ModelConfig(model_id=\"Lightricks/LTX-2\", 
origin_file_pattern=\"ltx-2-19b-distilled-lora-384.safetensors\"),\n)\npipe.load_lora(\n    pipe.dit,\n    ModelConfig(model_id=\"Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-In\", origin_file_pattern=\"ltx-2-19b-lora-camera-control-dolly-in.safetensors\"),\n)\n\nprompt = \"Dolly-in shot: A cheerful girl smiles brightly and says, 'I enjoy working with Diffsynth-Studio, it's a perfect framework.' The camera smoothly moves closer to her face, highlighting her enthusiasm and sincerity.\"\nnegative_prompt = (\n    \"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, \"\n    \"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, \"\n    \"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, \"\n    \"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of \"\n    \"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent \"\n    \"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny \"\n    \"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, \"\n    \"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, \"\n    \"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward \"\n    \"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, \"\n    \"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts.\"\n)\nheight, width, num_frames = 512 * 2, 768 * 2, 121\nvideo, audio = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    seed=43,\n    height=height,\n    width=width,\n    num_frames=num_frames,\n    tiled=True,\n    use_two_stage_pipeline=True,\n)\nwrite_video_audio_ltx2(\n    
video=video,\n    audio=audio,\n    output_path='ltx2_dolly_in_lora.mp4',\n    fps=24,\n    audio_sample_rate=24000,\n)\n"
  },
  {
    "path": "examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Left.py",
    "content": "import torch\nfrom diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig\nfrom diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2\n\nvram_config = {\n    \"offload_dtype\": torch.bfloat16,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cuda\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = LTX2AudioVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\", origin_file_pattern=\"model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"transformer.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"text_encoder_post_modules.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"video_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"audio_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"audio_vocoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"video_vae_encoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-spatial-upscaler-x2-1.0.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\"),\n    stage2_lora_config=ModelConfig(model_id=\"Lightricks/LTX-2\", 
origin_file_pattern=\"ltx-2-19b-distilled-lora-384.safetensors\"),\n)\npipe.load_lora(\n    pipe.dit,\n    ModelConfig(model_id=\"Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Left\", origin_file_pattern=\"ltx-2-19b-lora-camera-control-dolly-left.safetensors\"),\n)\n\nprompt = \"Dolly-left shot: A joyful young woman sits at a minimalist desk with a laptop running Diffsynth-Studio, code and generative visuals glowing on screen. She turns slightly toward the camera and says with a smile, 'I enjoy working with Diffsynth-Studio, it's a perfect framework.' As she speaks, the camera smoothly dollies left, revealing a wall of framed open-source project posters, a whiteboard covered in neural network sketches, and a shelf stacked with AI/graphics books beside her.\"\nnegative_prompt = (\n    \"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, \"\n    \"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, \"\n    \"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, \"\n    \"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of \"\n    \"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent \"\n    \"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny \"\n    \"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, \"\n    \"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, \"\n    \"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward \"\n    \"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, \"\n    \"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts.\"\n)\nheight, width, 
num_frames = 512 * 2, 768 * 2, 121\nvideo, audio = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    seed=43,\n    height=height,\n    width=width,\n    num_frames=num_frames,\n    tiled=True,\n    use_two_stage_pipeline=True,\n)\nwrite_video_audio_ltx2(\n    video=video,\n    audio=audio,\n    output_path='ltx2_dolly_left_lora.mp4',\n    fps=24,\n    audio_sample_rate=24000,\n)\n"
  },
  {
    "path": "examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Out.py",
    "content": "import torch\nfrom diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig\nfrom diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2\n\nvram_config = {\n    \"offload_dtype\": torch.bfloat16,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cuda\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = LTX2AudioVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\", origin_file_pattern=\"model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"transformer.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"text_encoder_post_modules.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"video_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"audio_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"audio_vocoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"video_vae_encoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-spatial-upscaler-x2-1.0.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\"),\n    stage2_lora_config=ModelConfig(model_id=\"Lightricks/LTX-2\", 
origin_file_pattern=\"ltx-2-19b-distilled-lora-384.safetensors\"),\n)\npipe.load_lora(\n    pipe.dit,\n    ModelConfig(model_id=\"Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Out\", origin_file_pattern=\"ltx-2-19b-lora-camera-control-dolly-out.safetensors\"),\n)\n\nprompt = \"Dolly-out shot: A joyful young woman smiles warmly and says: 'I enjoy working with Diffsynth-Studio, it's a perfect framework.' As she speaks, the camera slowly dollies out, revealing a bright, modern creative studio filled with plants, whiteboards full of diagrams, and soft natural light from large windows.\"\n\nnegative_prompt = (\n    \"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, \"\n    \"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, \"\n    \"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, \"\n    \"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of \"\n    \"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent \"\n    \"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny \"\n    \"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, \"\n    \"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, \"\n    \"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward \"\n    \"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, \"\n    \"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts.\"\n)\nheight, width, num_frames = 512 * 2, 768 * 2, 121\nvideo, audio = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    seed=43,\n    height=height,\n    width=width,\n    
num_frames=num_frames,\n    tiled=True,\n    use_two_stage_pipeline=True,\n)\nwrite_video_audio_ltx2(\n    video=video,\n    audio=audio,\n    output_path='ltx2_dolly_out.mp4',\n    fps=24,\n    audio_sample_rate=24000,\n)\n"
  },
  {
    "path": "examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Right.py",
    "content": "import torch\nfrom diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig\nfrom diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2\n\nvram_config = {\n    \"offload_dtype\": torch.bfloat16,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cuda\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = LTX2AudioVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\", origin_file_pattern=\"model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"transformer.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"text_encoder_post_modules.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"video_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"audio_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"audio_vocoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"video_vae_encoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-spatial-upscaler-x2-1.0.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\"),\n    stage2_lora_config=ModelConfig(model_id=\"Lightricks/LTX-2\", 
origin_file_pattern=\"ltx-2-19b-distilled-lora-384.safetensors\"),\n)\npipe.load_lora(\n    pipe.dit,\n    ModelConfig(model_id=\"Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Right\", origin_file_pattern=\"ltx-2-19b-lora-camera-control-dolly-right.safetensors\"),\n)\n\nprompt = \"Dolly-right shot: A happy girl looks up and says happily, 'I enjoy working with Diffsynth-Studio, it's a perfect framework.' She sits before a sunlit café table, her open laptop displaying the Github interface. The camera glides right to show a barista crafting coffee in the background, shelves of artisan beans, and a chalkboard menu softly blurred in the bokeh.\"\nnegative_prompt = (\n    \"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, \"\n    \"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, \"\n    \"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, \"\n    \"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of \"\n    \"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent \"\n    \"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny \"\n    \"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, \"\n    \"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, \"\n    \"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward \"\n    \"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, \"\n    \"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts.\"\n)\nheight, width, num_frames = 512 * 2, 768 * 2, 121\nvideo, audio = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    
seed=43,\n    height=height,\n    width=width,\n    num_frames=num_frames,\n    tiled=True,\n    use_two_stage_pipeline=True,\n)\nwrite_video_audio_ltx2(\n    video=video,\n    audio=audio,\n    output_path='ltx2_dolly_right_lora.mp4',\n    fps=24,\n    audio_sample_rate=24000,\n)\n"
  },
  {
    "path": "examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Jib-Down.py",
    "content": "import torch\nfrom diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig\nfrom diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2\n\nvram_config = {\n    \"offload_dtype\": torch.bfloat16,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cuda\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = LTX2AudioVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\", origin_file_pattern=\"model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"transformer.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"text_encoder_post_modules.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"video_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"audio_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"audio_vocoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"video_vae_encoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-spatial-upscaler-x2-1.0.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\"),\n    stage2_lora_config=ModelConfig(model_id=\"Lightricks/LTX-2\", 
origin_file_pattern=\"ltx-2-19b-distilled-lora-384.safetensors\"),\n)\npipe.load_lora(\n    pipe.dit,\n    ModelConfig(model_id=\"Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Down\", origin_file_pattern=\"ltx-2-19b-lora-camera-control-jib-down.safetensors\"),\n)\nprompt = (\n    \"A girl is very happy, standing on a clean studio floor with soft ambient lighting. \"\n    \"She is speaking directly to the camera: “I enjoy working with Diffsynth-Studio, it's a perfect framework.” \"\n    \"The shot begins with a medium close-up framing her from the waist up. As the camera performs a smooth jib-down movement—\"\n    \"descending vertically downward—it gradually reveals more of the lower portion of the scene. \"\n    \"During the descent, the following elements become visible near the bottom of the frame: \"\n    \"- The polished concrete floor with subtle reflections of the girl’s shoes, \"\n    \"- A small branded mat labeled “Diffsynth-Studio” placed just beneath her feet, \"\n    \"- The lower part of a sleek workstation desk with a glowing logo on its front panel, partially hidden at the start but fully revealed as the camera lowers. 
\"\n    \"This downward motion provides a dynamic reveal of contextual details that reinforce the professional and creative environment, \"\n    \"while maintaining focus on the girl’s enthusiastic expression throughout.\"\n)\nnegative_prompt = (\n    \"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, \"\n    \"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, \"\n    \"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, \"\n    \"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of \"\n    \"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent \"\n    \"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny \"\n    \"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, \"\n    \"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, \"\n    \"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward \"\n    \"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, \"\n    \"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts.\"\n)\nheight, width, num_frames = 512 * 2, 768 * 2, 121\nvideo, audio = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    seed=43,\n    height=height,\n    width=width,\n    num_frames=num_frames,\n    tiled=True,\n    use_two_stage_pipeline=True,\n)\nwrite_video_audio_ltx2(\n    video=video,\n    audio=audio,\n    output_path='ltx2_camera_jib_down.mp4',\n    fps=24,\n    audio_sample_rate=24000,\n)\n"
  },
  {
    "path": "examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Jib-Up.py",
    "content": "import torch\nfrom diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig\nfrom diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2\n\nvram_config = {\n    \"offload_dtype\": torch.bfloat16,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cuda\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = LTX2AudioVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\", origin_file_pattern=\"model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"transformer.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"text_encoder_post_modules.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"video_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"audio_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"audio_vocoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"video_vae_encoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-spatial-upscaler-x2-1.0.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\"),\n    stage2_lora_config=ModelConfig(model_id=\"Lightricks/LTX-2\", 
origin_file_pattern=\"ltx-2-19b-distilled-lora-384.safetensors\"),\n)\npipe.load_lora(\n    pipe.dit,\n    ModelConfig(model_id=\"Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Up\", origin_file_pattern=\"ltx-2-19b-lora-camera-control-jib-up.safetensors\"),\n)\nprompt = (\n    \"A girl stands happily at a sleek desk with a glowing 'Diffsynth-Studio' logo, saying: “I enjoy working with Diffsynth-Studio, it's a perfect framework.” \"\n    \"The shot starts low—framing her waist, shoes, and a branded floor mat—and smoothly jib-ups upward. \"\n    \"As the camera rises, it reveals her smiling face, upper body, and behind her: a bright creative studio with wall art and a large window showing daylight sky. \"\n    \"The final frame fully shows the inspiring workspace above the initial view, ensuring spatial continuity.\"\n)\nnegative_prompt = (\n    \"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, \"\n    \"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, \"\n    \"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, \"\n    \"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of \"\n    \"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent \"\n    \"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny \"\n    \"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, \"\n    \"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, \"\n    \"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward \"\n    \"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, \"\n    \"inconsistent tone, cinematic oversaturation, 
stylized filters, or AI artifacts.\"\n)\nheight, width, num_frames = 512 * 2, 768 * 2, 121\nvideo, audio = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    seed=43,\n    height=height,\n    width=width,\n    num_frames=num_frames,\n    tiled=True,\n    use_two_stage_pipeline=True,\n)\nwrite_video_audio_ltx2(\n    video=video,\n    audio=audio,\n    output_path='ltx2_camera_jib_up.mp4',\n    fps=24,\n    audio_sample_rate=24000,\n)\n"
  },
  {
    "path": "examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Static.py",
    "content": "import torch\nfrom diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig\nfrom diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2\n\nvram_config = {\n    \"offload_dtype\": torch.bfloat16,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cuda\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = LTX2AudioVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\", origin_file_pattern=\"model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"transformer.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"text_encoder_post_modules.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"video_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"audio_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"audio_vocoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"video_vae_encoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-spatial-upscaler-x2-1.0.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\"),\n    stage2_lora_config=ModelConfig(model_id=\"Lightricks/LTX-2\", 
origin_file_pattern=\"ltx-2-19b-distilled-lora-384.safetensors\"),\n)\npipe.load_lora(\n    pipe.dit,\n    ModelConfig(model_id=\"Lightricks/LTX-2-19b-LoRA-Camera-Control-Static\", origin_file_pattern=\"ltx-2-19b-lora-camera-control-static.safetensors\"),\n)\nprompt = \"A beautiful sunset over the ocean.\"\nnegative_prompt = (\n    \"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, \"\n    \"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, \"\n    \"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, \"\n    \"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of \"\n    \"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent \"\n    \"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny \"\n    \"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, \"\n    \"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, \"\n    \"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward \"\n    \"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, \"\n    \"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts.\"\n)\nheight, width, num_frames = 512 * 2, 768 * 2, 121\nvideo, audio = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    seed=43,\n    height=height,\n    width=width,\n    num_frames=num_frames,\n    tiled=True,\n    use_two_stage_pipeline=True,\n)\nwrite_video_audio_ltx2(\n    video=video,\n    audio=audio,\n    output_path='ltx2_camera_static.mp4',\n    fps=24,\n    audio_sample_rate=24000,\n)\n"
  },
  {
    "path": "examples/ltx2/model_inference/LTX-2-T2AV-DistilledPipeline.py",
    "content": "import torch\nfrom diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig\nfrom diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2\n\nvram_config = {\n    \"offload_dtype\": torch.bfloat16,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cuda\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = LTX2AudioVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\", origin_file_pattern=\"model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"transformer_distilled.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"text_encoder_post_modules.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"video_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"audio_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"audio_vocoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"video_vae_encoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-spatial-upscaler-x2-1.0.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\"),\n)\n\nprompt = \"A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”\"\nnegative_prompt 
= (\n    \"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, \"\n    \"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, \"\n    \"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, \"\n    \"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of \"\n    \"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent \"\n    \"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny \"\n    \"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, \"\n    \"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, \"\n    \"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward \"\n    \"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, \"\n    \"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts.\"\n)\nheight, width, num_frames = 512 * 2, 768 * 2, 121\nvideo, audio = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    seed=43,\n    height=height,\n    width=width,\n    num_frames=num_frames,\n    tiled=True,\n    use_distilled_pipeline=True,\n)\nwrite_video_audio_ltx2(\n    video=video,\n    audio=audio,\n    output_path='ltx2_distilled.mp4',\n    fps=24,\n    audio_sample_rate=24000,\n)\n"
  },
  {
    "path": "examples/ltx2/model_inference/LTX-2-T2AV-IC-LoRA-Detailer.py",
    "content": "import torch\nfrom diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig\nfrom diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2\nfrom diffsynth.utils.data import VideoData\nfrom modelscope import dataset_snapshot_download\n\nvram_config = {\n    \"offload_dtype\": torch.bfloat16,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cuda\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = LTX2AudioVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\", origin_file_pattern=\"model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"transformer.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"text_encoder_post_modules.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"video_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"audio_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"audio_vocoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"video_vae_encoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-spatial-upscaler-x2-1.0.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\"),\n    
stage2_lora_config=ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-19b-distilled-lora-384.safetensors\"),\n)\npipe.load_lora(pipe.dit, ModelConfig(model_id=\"Lightricks/LTX-2-19b-IC-LoRA-Detailer\", origin_file_pattern=\"ltx-2-19b-ic-lora-detailer.safetensors\"))\ndataset_snapshot_download(\"DiffSynth-Studio/example_video_dataset\", allow_file_pattern=\"ltx2/*\", local_dir=\"data/example_video_dataset\")\n\nprompt = \"[VISUAL]:Two cute orange cats, wearing boxing gloves, stand on a boxing ring and fight each other. [SOUNDS]:the sound of two cats boxing\"\nnegative_prompt = (\n    \"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, \"\n    \"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, \"\n    \"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, \"\n    \"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of \"\n    \"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent \"\n    \"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny \"\n    \"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, \"\n    \"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, \"\n    \"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward \"\n    \"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, \"\n    \"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts.\"\n)\nheight, width, num_frames = 512 * 2, 768 * 2, 121\nref_scale_factor = 1\nframe_rate = 24\n# the frame rate of the video should better be the same with the reference video\n# the spatial resolution of 
the first frame should be the resolution of stage 1 video generation divided by ref_scale_factor\ninput_video = VideoData(\"data/example_video_dataset/ltx2/video1.mp4\", height=height // ref_scale_factor // 2, width=width // ref_scale_factor // 2)\ninput_video = input_video.raw_data()\nvideo, audio = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    seed=43,\n    height=height,\n    width=width,\n    num_frames=num_frames,\n    frame_rate=frame_rate,\n    in_context_videos=[input_video],\n    in_context_downsample_factor=ref_scale_factor,\n    tiled=True,\n    use_two_stage_pipeline=True,\n    clear_lora_before_state_two=True,\n)\nwrite_video_audio_ltx2(\n    video=video,\n    audio=audio,\n    output_path='ltx2_twostage_iclora.mp4',\n    fps=frame_rate,\n    audio_sample_rate=24000,\n)\n"
  },
  {
    "path": "examples/ltx2/model_inference/LTX-2-T2AV-IC-LoRA-Union-Control.py",
    "content": "import torch\nfrom diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig\nfrom diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2\nfrom diffsynth.utils.data import VideoData\nfrom modelscope import dataset_snapshot_download\n\nvram_config = {\n    \"offload_dtype\": torch.bfloat16,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cuda\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = LTX2AudioVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\", origin_file_pattern=\"model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"transformer.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"text_encoder_post_modules.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"video_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"audio_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"audio_vocoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"video_vae_encoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-spatial-upscaler-x2-1.0.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\"),\n    
stage2_lora_config=ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-19b-distilled-lora-384.safetensors\"),\n)\npipe.load_lora(pipe.dit, ModelConfig(model_id=\"Lightricks/LTX-2-19b-IC-LoRA-Union-Control\", origin_file_pattern=\"ltx-2-19b-ic-lora-union-control-ref0.5.safetensors\"))\ndataset_snapshot_download(\"DiffSynth-Studio/example_video_dataset\", allow_file_pattern=\"ltx2/*\", local_dir=\"data/example_video_dataset\")\n\nprompt = \"[VISUAL]:Two cute orange cats, wearing boxing gloves, stand on a boxing ring and fight each other. [SOUNDS]:the sound of two cats boxing\"\nnegative_prompt = (\n    \"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, \"\n    \"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, \"\n    \"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, \"\n    \"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of \"\n    \"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent \"\n    \"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny \"\n    \"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, \"\n    \"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, \"\n    \"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward \"\n    \"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, \"\n    \"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts.\"\n)\nheight, width, num_frames = 512 * 2, 768 * 2, 121\nref_scale_factor = 2\nframe_rate = 24\n# the frame rate of the video should better be the same with the reference video\n# the spatial 
resolution of the first frame should be the resolution of stage 1 video generation divided by ref_scale_factor\ninput_video = VideoData(\"data/example_video_dataset/ltx2/depth_video.mp4\", height=height // ref_scale_factor // 2, width=width // ref_scale_factor // 2)\ninput_video = input_video.raw_data()\nvideo, audio = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    seed=43,\n    height=height,\n    width=width,\n    num_frames=num_frames,\n    frame_rate=frame_rate,\n    in_context_videos=[input_video],\n    in_context_downsample_factor=ref_scale_factor,\n    tiled=True,\n    use_two_stage_pipeline=True,\n    clear_lora_before_state_two=True,\n)\nwrite_video_audio_ltx2(\n    video=video,\n    audio=audio,\n    output_path='ltx2_twostage_iclora.mp4',\n    fps=frame_rate,\n    audio_sample_rate=24000,\n)\n"
  },
  {
    "path": "examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py",
    "content": "import torch\nfrom diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig\nfrom diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2\n\nvram_config = {\n    \"offload_dtype\": torch.bfloat16,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cuda\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\n\"\"\"\nOffical model repo: https://www.modelscope.cn/models/Lightricks/LTX-2\nRepackaged model repo: https://www.modelscope.cn/models/DiffSynth-Studio/LTX-2-Repackage\nFor base models of LTX-2, offical checkpoint (with model config ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-19b-dev.safetensors\"))\nand repackaged checkpoints (with model config ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"*.safetensors\")) are both supported.\nWe have repackeged the official checkpoints in DiffSynth-Studio/LTX-2-Repackage repo to support separate loading of different submodules,\nand avoid redundant memory usage when users only want to use part of the model.\n\"\"\"\n# use the repackaged modelconfig from \"DiffSynth-Studio/LTX-2-Repackage\" to avoid redundant model loading\npipe = LTX2AudioVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\", origin_file_pattern=\"model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"transformer.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"text_encoder_post_modules.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", 
origin_file_pattern=\"video_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"audio_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"audio_vocoder.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\"),\n)\n# use the following modelconfig if you want to initialize model from offical checkpoints from \"Lightricks/LTX-2\"\n# pipe = LTX2AudioVideoPipeline.from_pretrained(\n#     torch_dtype=torch.bfloat16,\n#     device=\"cuda\",\n#     model_configs=[\n#         ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\", origin_file_pattern=\"model-*.safetensors\", **vram_config),\n#         ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-19b-dev.safetensors\", **vram_config),\n#     ],\n#     tokenizer_config=ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\"),\n# )\nprompt = \"A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”\"\nnegative_prompt = \"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, 
background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts.\"\nheight, width, num_frames = 512, 768, 121\nvideo, audio = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    seed=43,\n    height=height,\n    width=width,\n    num_frames=num_frames,\n    tiled=True,\n)\nwrite_video_audio_ltx2(\n    video=video,\n    audio=audio,\n    output_path='ltx2_onestage.mp4',\n    fps=24,\n    audio_sample_rate=24000,\n)\n"
  },
  {
    "path": "examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py",
    "content": "import torch\nfrom diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig\nfrom diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2\n\nvram_config = {\n    \"offload_dtype\": torch.bfloat16,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cuda\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\n\"\"\"\nOffical model repo: https://www.modelscope.cn/models/Lightricks/LTX-2\nRepackaged model repo: https://www.modelscope.cn/models/DiffSynth-Studio/LTX-2-Repackage\nFor base models of LTX-2, offical checkpoint (with model config ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-19b-dev.safetensors\"))\nand repackaged checkpoints (with model config ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"*.safetensors\")) are both supported.\nWe have repackeged the official checkpoints in DiffSynth-Studio/LTX-2-Repackage repo to support separate loading of different submodules,\nand avoid redundant memory usage when users only want to use part of the model.\n\"\"\"\n# use the repackaged modelconfig from \"DiffSynth-Studio/LTX-2-Repackage\" to avoid redundant model loading\npipe = LTX2AudioVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\", origin_file_pattern=\"model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"transformer.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"text_encoder_post_modules.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", 
origin_file_pattern=\"video_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"audio_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"audio_vocoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"video_vae_encoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-spatial-upscaler-x2-1.0.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\"),\n    stage2_lora_config=ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-19b-distilled-lora-384.safetensors\"),\n)\n# use the following modelconfig if you want to initialize model from offical checkpoints from \"Lightricks/LTX-2\"\n# pipe = LTX2AudioVideoPipeline.from_pretrained(\n#     torch_dtype=torch.bfloat16,\n#     device=\"cuda\",\n#     model_configs=[\n#         ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\", origin_file_pattern=\"model-*.safetensors\", **vram_config),\n#         ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-19b-dev.safetensors\", **vram_config),\n#         ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-spatial-upscaler-x2-1.0.safetensors\", **vram_config),\n#     ],\n#     tokenizer_config=ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\"),\n#     stage2_lora_config=ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-19b-distilled-lora-384.safetensors\"),\n# )\nprompt = \"A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”\"\nnegative_prompt = (\n    \"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, \"\n 
   \"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, \"\n    \"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, \"\n    \"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of \"\n    \"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent \"\n    \"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny \"\n    \"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, \"\n    \"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, \"\n    \"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward \"\n    \"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, \"\n    \"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts.\"\n)\nheight, width, num_frames = 512 * 2, 768 * 2, 121\nvideo, audio = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    seed=43,\n    height=height,\n    width=width,\n    num_frames=num_frames,\n    tiled=True,\n    use_two_stage_pipeline=True,\n)\nwrite_video_audio_ltx2(\n    video=video,\n    audio=audio,\n    output_path='ltx2_twostage.mp4',\n    fps=24,\n    audio_sample_rate=24000,\n)\n"
  },
  {
    "path": "examples/ltx2/model_inference/LTX-2.3-A2V-TwoStage.py",
    "content": "import torch\nfrom diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig\nfrom diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2\nfrom diffsynth.utils.data.audio import read_audio\nfrom modelscope import dataset_snapshot_download\n\nvram_config = {\n    \"offload_dtype\": torch.bfloat16,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cuda\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = LTX2AudioVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\", origin_file_pattern=\"model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Lightricks/LTX-2.3\", origin_file_pattern=\"ltx-2.3-22b-dev.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Lightricks/LTX-2.3\", origin_file_pattern=\"ltx-2.3-spatial-upscaler-x2-1.0.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\"),\n    stage2_lora_config=ModelConfig(model_id=\"Lightricks/LTX-2.3\", origin_file_pattern=\"ltx-2.3-22b-distilled-lora-384.safetensors\"),\n)\n\ndataset_snapshot_download(\"DiffSynth-Studio/example_video_dataset\", allow_file_pattern=\"ltx2/*\", local_dir=\"data/example_video_dataset\")\nprompt = \"A beautiful woman with a flower crown is singing happily under a blooming cherry tree.\"\nnegative_prompt = (\n    \"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, \"\n    \"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, \"\n    \"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, \"\n    \"wrong 
hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of \"\n    \"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent \"\n    \"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny \"\n    \"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, \"\n    \"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, \"\n    \"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward \"\n    \"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, \"\n    \"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts.\"\n)\nheight, width, num_frames, frame_rate = 512 * 2, 768 * 2, 121, 24\nduration = num_frames / frame_rate\naudio, audio_sample_rate = read_audio(\"data/example_video_dataset/ltx2/sing.MP3\", start_time=1, duration=duration)\nvideo, audio = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    retake_audio=audio,\n    audio_sample_rate=audio_sample_rate,\n    seed=43,\n    height=height,\n    width=width,\n    num_frames=num_frames,\n    frame_rate=frame_rate,\n    tiled=True,\n    use_two_stage_pipeline=True,\n)\nwrite_video_audio_ltx2(\n    video=video,\n    audio=audio,\n    output_path='ltx2.3_twostage_a2v.mp4',\n    fps=frame_rate,\n    audio_sample_rate=pipe.audio_vocoder.output_sampling_rate,\n)\n"
  },
  {
    "path": "examples/ltx2/model_inference/LTX-2.3-I2AV-DistilledPipeline.py",
    "content": "import torch\nfrom diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig\nfrom diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2\nfrom PIL import Image\nfrom modelscope import dataset_snapshot_download\n\nvram_config = {\n    \"offload_dtype\": torch.bfloat16,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cuda\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = LTX2AudioVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\", origin_file_pattern=\"model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Lightricks/LTX-2.3\", origin_file_pattern=\"ltx-2.3-22b-distilled.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Lightricks/LTX-2.3\", origin_file_pattern=\"ltx-2.3-spatial-upscaler-x2-1.0.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\"),\n)\n\n\nprompt = \"A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”\"\nnegative_prompt = (\n    \"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, \"\n    \"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, \"\n    \"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, \"\n    \"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of \"\n    \"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent \"\n    \"lighting direction, color banding, cartoonish rendering, 3D CGI look, 
unrealistic materials, uncanny \"\n    \"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, \"\n    \"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, \"\n    \"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward \"\n    \"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, \"\n    \"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts.\"\n)\nheight, width, num_frames = 512 * 2, 768 * 2, 121\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/examples_in_diffsynth\",\n    local_dir=\"./\",\n    allow_file_pattern=[\"data/examples/ltx-2/first_frame.jpg\"]\n)\nimage = Image.open(\"data/examples/ltx-2/first_frame.jpg\").convert(\"RGB\").resize((width, height))\n# first frame\nvideo, audio = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    seed=43,\n    height=height,\n    width=width,\n    num_frames=num_frames,\n    tiled=True,\n    use_distilled_pipeline=True,\n    input_images=[image],\n    input_images_indexes=[0],\n    input_images_strength=1.0,\n)\nwrite_video_audio_ltx2(\n    video=video,\n    audio=audio,\n    output_path='ltx2.3_distilled_i2av_first.mp4',\n    fps=24,\n    audio_sample_rate=pipe.audio_vocoder.output_sampling_rate,\n)\n"
  },
  {
    "path": "examples/ltx2/model_inference/LTX-2.3-I2AV-OneStage.py",
    "content": "import torch\nfrom diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig\nfrom diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2\nfrom PIL import Image\nfrom modelscope import dataset_snapshot_download\n\nvram_config = {\n    \"offload_dtype\": torch.bfloat16,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cuda\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = LTX2AudioVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\", origin_file_pattern=\"model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Lightricks/LTX-2.3\", origin_file_pattern=\"ltx-2.3-22b-dev.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\"),\n)\n\nprompt = \"A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”\"\nnegative_prompt = \"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, 
echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts.\"\nheight, width, num_frames = 512, 768, 121\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/examples_in_diffsynth\",\n    local_dir=\"./\",\n    allow_file_pattern=[\"data/examples/ltx-2/first_frame.jpg\"]\n)\nimage = Image.open(\"data/examples/ltx-2/first_frame.jpg\").convert(\"RGB\").resize((width, height))\n# first frame\nvideo, audio = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    seed=43,\n    height=height,\n    width=width,\n    num_frames=num_frames,\n    tiled=False,\n    input_images=[image],\n    input_images_indexes=[0],\n    input_images_strength=1.0,\n)\nwrite_video_audio_ltx2(\n    video=video,\n    audio=audio,\n    output_path='ltx2.3_onestage_i2av_first.mp4',\n    fps=24,\n    audio_sample_rate=pipe.audio_vocoder.output_sampling_rate,\n)\n"
  },
  {
    "path": "examples/ltx2/model_inference/LTX-2.3-I2AV-TwoStage.py",
    "content": "import torch\nfrom diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig\nfrom diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2\nfrom PIL import Image\nfrom modelscope import dataset_snapshot_download\n\nvram_config = {\n    \"offload_dtype\": torch.bfloat16,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cuda\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = LTX2AudioVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\", origin_file_pattern=\"model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Lightricks/LTX-2.3\", origin_file_pattern=\"ltx-2.3-22b-dev.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Lightricks/LTX-2.3\", origin_file_pattern=\"ltx-2.3-spatial-upscaler-x2-1.0.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\"),\n    stage2_lora_config=ModelConfig(model_id=\"Lightricks/LTX-2.3\", origin_file_pattern=\"ltx-2.3-22b-distilled-lora-384.safetensors\"),\n)\n\nprompt = \"Two cute orange cats, wearing boxing gloves, stand in a boxing ring and fight each other. 
They are punching each other fast and yelling: 'I will win!'\"\nnegative_prompt = (\n    \"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, \"\n    \"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, \"\n    \"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, \"\n    \"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of \"\n    \"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent \"\n    \"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny \"\n    \"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, \"\n    \"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, \"\n    \"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward \"\n    \"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, \"\n    \"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts.\"\n)\nheight, width, num_frames = 512 * 2, 768 * 2, 121\ndataset_snapshot_download(\"DiffSynth-Studio/example_video_dataset\", allow_file_pattern=\"ltx2/*\", local_dir=\"data/example_video_dataset\")\nfirst_frame = Image.open(\"data/example_video_dataset/ltx2/first_frame.png\").convert(\"RGB\").resize((width, height))\nlast_frame = Image.open(\"data/example_video_dataset/ltx2/last_frame.png\").convert(\"RGB\").resize((width, height))\n# first frame\nvideo, audio = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    seed=42,\n    height=height,\n    width=width,\n    num_frames=num_frames,\n    tiled=True,\n    use_two_stage_pipeline=True,\n    input_images=[first_frame],\n    input_images_indexes=[0],\n 
   input_images_strength=1.0,\n)\nwrite_video_audio_ltx2(\n    video=video,\n    audio=audio,\n    output_path='ltx2.3_twostage_i2av_first.mp4',\n    fps=24,\n)\npipe.clear_lora()\n\n# This example uses the first and last frames for demonstration. However, you can condition on any frames by setting input_images and input_images_indexes. Note that every index in input_images_indexes must be in the range [0, num_frames - 1].\nvideo, audio = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    seed=42,\n    height=height,\n    width=width,\n    num_frames=num_frames,\n    tiled=True,\n    use_two_stage_pipeline=True,\n    input_images=[first_frame, last_frame],\n    input_images_indexes=[0, num_frames-1],\n    input_images_strength=1.0,\n)\nwrite_video_audio_ltx2(\n    video=video,\n    audio=audio,\n    output_path='ltx2.3_twostage_i2av_first_last.mp4',\n    fps=24,\n)\n"
  },
  {
    "path": "examples/ltx2/model_inference/LTX-2.3-T2AV-DistilledPipeline.py",
    "content": "import torch\nfrom diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig\nfrom diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2\n\nvram_config = {\n    \"offload_dtype\": torch.bfloat16,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cuda\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = LTX2AudioVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\", origin_file_pattern=\"model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Lightricks/LTX-2.3\", origin_file_pattern=\"ltx-2.3-22b-distilled.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Lightricks/LTX-2.3\", origin_file_pattern=\"ltx-2.3-spatial-upscaler-x2-1.0.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\"),\n)\n\nprompt = \"A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”\"\nnegative_prompt = (\n    \"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, \"\n    \"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, \"\n    \"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, \"\n    \"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of \"\n    \"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent \"\n    \"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny \"\n    \"valley effect, incorrect 
ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, \"\n    \"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, \"\n    \"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward \"\n    \"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, \"\n    \"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts.\"\n)\nheight, width, num_frames = 512 * 2, 768 * 2, 121\nvideo, audio = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    seed=43,\n    height=height,\n    width=width,\n    num_frames=num_frames,\n    tiled=True,\n    use_distilled_pipeline=True,\n)\nwrite_video_audio_ltx2(\n    video=video,\n    audio=audio,\n    output_path='ltx2.3_distilled.mp4',\n    fps=24,\n    audio_sample_rate=pipe.audio_vocoder.output_sampling_rate,\n)\n"
  },
  {
    "path": "examples/ltx2/model_inference/LTX-2.3-T2AV-IC-LoRA-Motion-Track-Control.py",
    "content": "import torch\nfrom diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig\nfrom diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2\nfrom modelscope import dataset_snapshot_download\nfrom diffsynth.utils.data import VideoData\n\nvram_config = {\n    \"offload_dtype\": torch.bfloat16,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cuda\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = LTX2AudioVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\", origin_file_pattern=\"model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Lightricks/LTX-2.3\", origin_file_pattern=\"ltx-2.3-22b-dev.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Lightricks/LTX-2.3\", origin_file_pattern=\"ltx-2.3-spatial-upscaler-x2-1.0.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\"),\n    stage2_lora_config=ModelConfig(model_id=\"Lightricks/LTX-2.3\", origin_file_pattern=\"ltx-2.3-22b-distilled-lora-384.safetensors\"),\n)\npipe.load_lora(pipe.dit, ModelConfig(model_id=\"Lightricks/LTX-2.3-22b-IC-LoRA-Motion-Track-Control\", origin_file_pattern=\"ltx-2.3-22b-ic-lora-motion-track-control-ref0.5.safetensors\"))\ndataset_snapshot_download(\"DiffSynth-Studio/example_video_dataset\", allow_file_pattern=\"ltx2/*\", local_dir=\"data/example_video_dataset\")\nprompt = \"[VISUAL]:Two cute orange cats, wearing boxing gloves, stand on a boxing ring and fight each other. 
[SOUNDS]:the sound of two cats boxing\"\nnegative_prompt = (\n    \"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, \"\n    \"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, \"\n    \"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, \"\n    \"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of \"\n    \"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent \"\n    \"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny \"\n    \"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, \"\n    \"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, \"\n    \"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward \"\n    \"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, \"\n    \"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts.\"\n)\nheight, width, num_frames = 512 * 2, 768 * 2, 121\nref_scale_factor = 2\nframe_rate = 24\ninput_image = VideoData(\"data/example_video_dataset/ltx2/video1.mp4\", height=height // ref_scale_factor // 2, width=width // ref_scale_factor // 2)[0]\ninput_video = VideoData(\"data/example_video_dataset/ltx2/spatial_tracker_v2.mp4\", height=height // ref_scale_factor // 2, width=width // ref_scale_factor // 2).raw_data()\nvideo, audio = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    seed=43,\n    height=height,\n    width=width,\n    num_frames=num_frames,\n    frame_rate=frame_rate,\n    in_context_videos=[input_video],\n    in_context_downsample_factor=ref_scale_factor,\n    input_images=[input_image],\n    
input_images_indexes=[0],\n    tiled=True,\n    use_two_stage_pipeline=True,\n    clear_lora_before_state_two=True,\n)\nwrite_video_audio_ltx2(\n    video=video,\n    audio=audio,\n    output_path='ltx2.3_ic_lora.mp4',\n    fps=frame_rate,\n    audio_sample_rate=pipe.audio_vocoder.output_sampling_rate,\n)\n"
  },
  {
    "path": "examples/ltx2/model_inference/LTX-2.3-T2AV-IC-LoRA-Union-Control.py",
    "content": "import torch\nfrom diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig\nfrom diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2\nfrom modelscope import dataset_snapshot_download\nfrom diffsynth.utils.data import VideoData\n\nvram_config = {\n    \"offload_dtype\": torch.bfloat16,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cuda\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = LTX2AudioVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\", origin_file_pattern=\"model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Lightricks/LTX-2.3\", origin_file_pattern=\"ltx-2.3-22b-dev.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Lightricks/LTX-2.3\", origin_file_pattern=\"ltx-2.3-spatial-upscaler-x2-1.0.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\"),\n    stage2_lora_config=ModelConfig(model_id=\"Lightricks/LTX-2.3\", origin_file_pattern=\"ltx-2.3-22b-distilled-lora-384.safetensors\"),\n)\npipe.load_lora(pipe.dit, ModelConfig(model_id=\"Lightricks/LTX-2.3-22b-IC-LoRA-Union-Control\", origin_file_pattern=\"ltx-2.3-22b-ic-lora-union-control-ref0.5.safetensors\"))\ndataset_snapshot_download(\"DiffSynth-Studio/example_video_dataset\", allow_file_pattern=\"ltx2/*\", local_dir=\"data/example_video_dataset\")\nprompt = \"[VISUAL]:Two cute orange cats, wearing boxing gloves, stand on a boxing ring and fight each other. 
[SOUNDS]:the sound of two cats boxing\"\nnegative_prompt = (\n    \"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, \"\n    \"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, \"\n    \"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, \"\n    \"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of \"\n    \"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent \"\n    \"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny \"\n    \"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, \"\n    \"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, \"\n    \"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward \"\n    \"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, \"\n    \"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts.\"\n)\nheight, width, num_frames = 512 * 2, 768 * 2, 121\nref_scale_factor = 2\nframe_rate = 24\ninput_video = VideoData(\"data/example_video_dataset/ltx2/depth_video.mp4\", height=height // ref_scale_factor // 2, width=width // ref_scale_factor // 2).raw_data()\nvideo, audio = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    seed=43,\n    height=height,\n    width=width,\n    num_frames=num_frames,\n    frame_rate=frame_rate,\n    in_context_videos=[input_video],\n    in_context_downsample_factor=ref_scale_factor,\n    tiled=True,\n    use_two_stage_pipeline=True,\n    clear_lora_before_state_two=True,\n)\nwrite_video_audio_ltx2(\n    video=video,\n    audio=audio,\n    output_path='ltx2.3_ic_lora.mp4',\n    
fps=frame_rate,\n    audio_sample_rate=pipe.audio_vocoder.output_sampling_rate,\n)\n"
  },
  {
    "path": "examples/ltx2/model_inference/LTX-2.3-T2AV-OneStage.py",
    "content": "import torch\nfrom diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig\nfrom diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2\n\nvram_config = {\n    \"offload_dtype\": torch.bfloat16,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cuda\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\n# use the repackaged modelconfig from \"DiffSynth-Studio/LTX-2.3-Repackage\" to avoid redundant model loading\npipe = LTX2AudioVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\", origin_file_pattern=\"model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2.3-Repackage\", origin_file_pattern=\"transformer.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2.3-Repackage\", origin_file_pattern=\"text_encoder_post_modules.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2.3-Repackage\", origin_file_pattern=\"video_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2.3-Repackage\", origin_file_pattern=\"audio_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2.3-Repackage\", origin_file_pattern=\"audio_vocoder.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\"),\n)\n# use the following modelconfig if you want to initialize model from offical checkpoints from \"Lightricks/LTX-2.3\"\n# pipe = LTX2AudioVideoPipeline.from_pretrained(\n#     torch_dtype=torch.bfloat16,\n#     device=\"cuda\",\n#     model_configs=[\n#         
ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\", origin_file_pattern=\"model-*.safetensors\", **vram_config),\n#         ModelConfig(model_id=\"Lightricks/LTX-2.3\", origin_file_pattern=\"ltx-2.3-22b-dev.safetensors\", **vram_config),\n#     ],\n#     tokenizer_config=ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\"),\n# )\nprompt = \"A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”\"\nnegative_prompt = \"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts.\"\nheight, width, num_frames = 512, 768, 121\nvideo, audio = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    seed=43,\n    height=height,\n    width=width,\n    num_frames=num_frames,\n    tiled=True,\n)\nwrite_video_audio_ltx2(\n    video=video,\n    audio=audio,\n    output_path='ltx2.3_onestage.mp4',\n    fps=24,\n    audio_sample_rate=pipe.audio_vocoder.output_sampling_rate,\n)\n"
  },
  {
    "path": "examples/ltx2/model_inference/LTX-2.3-T2AV-TwoStage-Retake.py",
    "content": "# LTX-2.3 two-stage retake example: regenerates selected time regions of an existing clip's video and audio.\nimport torch\nfrom diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig\nfrom diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2\nfrom diffsynth.utils.data.audio import read_audio\nfrom modelscope import dataset_snapshot_download\nfrom diffsynth.utils.data import VideoData\n\nvram_config = {\n    \"offload_dtype\": torch.bfloat16,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cuda\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = LTX2AudioVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\", origin_file_pattern=\"model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Lightricks/LTX-2.3\", origin_file_pattern=\"ltx-2.3-22b-dev.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Lightricks/LTX-2.3\", origin_file_pattern=\"ltx-2.3-spatial-upscaler-x2-1.0.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\"),\n    stage2_lora_config=ModelConfig(model_id=\"Lightricks/LTX-2.3\", origin_file_pattern=\"ltx-2.3-22b-distilled-lora-384.safetensors\"),\n)\n\ndataset_snapshot_download(\"DiffSynth-Studio/example_video_dataset\", allow_file_pattern=\"ltx2/*\", local_dir=\"data/example_video_dataset\")\nprompt = \"A beautiful woman with a flower crown is singing happily under a blooming cherry tree. She sings: 'Mummy don't know daddy's getting hot. At the body shop'\"\nnegative_prompt = (\n    \"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, \"\n    \"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, \"\n    \"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, \"\n    \"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of \"\n    \"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent \"\n    \"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny \"\n    \"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, \"\n    \"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, \"\n    \"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward \"\n    \"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, \"\n    \"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts.\"\n)\n\nheight, width, num_frames, frame_rate = 512 * 2, 768 * 2, 121, 24\npath = \"data/example_video_dataset/ltx2/video2.mp4\"\nvideo = VideoData(path, height=height, width=width).raw_data()[:num_frames]\nassert len(video) == num_frames, f\"Input video has {len(video)} frames, but expected {num_frames} frames based on the specified num_frames argument.\"\nduration = num_frames / frame_rate  # clip length in seconds\naudio, audio_sample_rate = read_audio(path)\n\n# Regenerate the video within time regions. You can specify different time regions for video frames and audio retake.\n# Retake regions are (start, end) pairs in seconds: the example below retakes video frames in [1s, 2s] and [3s, 4s],\n# and retakes audio in [0s, 1s] and [4s, 5s].\nretake_video_regions = [(1, 2), (3, 4)]\nretake_audio_regions = [(0, 1), (4, 5)]\n# Catch empty or out-of-range regions before calling the pipeline.\nfor start, end in retake_video_regions + retake_audio_regions:\n    assert 0 <= start < end <= duration, f\"Retake region ({start}, {end}) must satisfy 0 <= start < end <= {duration:.2f}s.\"\nvideo, audio = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    retake_video=video,\n    retake_video_regions=retake_video_regions,\n    retake_audio=audio,\n    audio_sample_rate=audio_sample_rate,\n    retake_audio_regions=retake_audio_regions,\n    seed=43,\n    height=height,\n    width=width,\n    num_frames=num_frames,\n    frame_rate=frame_rate,\n    tiled=True,\n    use_two_stage_pipeline=True,\n)\nwrite_video_audio_ltx2(\n    video=video,\n    audio=audio,\n    output_path='ltx2.3_twostage_retake.mp4',\n    fps=frame_rate,\n    audio_sample_rate=pipe.audio_vocoder.output_sampling_rate,\n)\n"
  },
  {
    "path": "examples/ltx2/model_inference/LTX-2.3-T2AV-TwoStage.py",
    "content": "# LTX-2.3 two-stage text-to-audio-video example.\nimport torch\nfrom diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig\nfrom diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2\n\nGEMMA_ID = \"google/gemma-3-12b-it-qat-q4_0-unquantized\"\nLTX_ID = \"Lightricks/LTX-2.3\"\n\n# bfloat16 at every stage; offload_device is the CPU, every other stage runs on CUDA.\nvram_config = dict(\n    offload_dtype=torch.bfloat16,\n    offload_device=\"cpu\",\n    onload_dtype=torch.bfloat16,\n    onload_device=\"cuda\",\n    preparing_dtype=torch.bfloat16,\n    preparing_device=\"cuda\",\n    computation_dtype=torch.bfloat16,\n    computation_device=\"cuda\",\n)\n# Gemma weights first, then the LTX-2.3 dev checkpoint and the x2 spatial upscaler.\nmodel_configs = [ModelConfig(model_id=GEMMA_ID, origin_file_pattern=\"model-*.safetensors\", **vram_config)]\nfor weight_file in (\"ltx-2.3-22b-dev.safetensors\", \"ltx-2.3-spatial-upscaler-x2-1.0.safetensors\"):\n    model_configs.append(ModelConfig(model_id=LTX_ID, origin_file_pattern=weight_file, **vram_config))\npipe = LTX2AudioVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=model_configs,\n    tokenizer_config=ModelConfig(model_id=GEMMA_ID),\n    # Distilled LoRA passed to the pipeline for its second stage.\n    stage2_lora_config=ModelConfig(model_id=LTX_ID, origin_file_pattern=\"ltx-2.3-22b-distilled-lora-384.safetensors\"),\n)\nprompt = \"A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”\"\nnegative_prompt = (\n    \"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, \"\n    \"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, \"\n    \"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, \"\n    \"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of \"\n    \"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent \"\n    \"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny \"\n    \"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, \"\n    \"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, \"\n    \"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward \"\n    \"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, \"\n    \"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts.\"\n)\nheight = 512 * 2\nwidth = 768 * 2\nnum_frames = 121\ngeneration_kwargs = dict(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    seed=43,\n    height=height,\n    width=width,\n    num_frames=num_frames,\n    tiled=True,\n    use_two_stage_pipeline=True,\n)\nvideo, audio = pipe(**generation_kwargs)\nwrite_video_audio_ltx2(\n    video=video, audio=audio,\n    output_path='ltx2.3_twostage.mp4',\n    fps=24, audio_sample_rate=pipe.audio_vocoder.output_sampling_rate,\n)\n"
  },
  {
    "path": "examples/ltx2/model_inference_low_vram/LTX-2-I2AV-DistilledPipeline.py",
    "content": "# LTX-2 distilled-pipeline image-to-audio-video example (low-VRAM settings); an input image conditions frame 0.\nimport torch\nfrom diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig\nfrom diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2\nfrom PIL import Image\nfrom modelscope import dataset_snapshot_download\n\nGEMMA_ID = \"google/gemma-3-12b-it-qat-q4_0-unquantized\"\nREPACKAGE_ID = \"DiffSynth-Studio/LTX-2-Repackage\"\n\n# offload/onload: float8_e5m2 on CPU; preparing: float8_e5m2 on CUDA; computation: bfloat16 on CUDA.\nvram_config = dict(\n    offload_dtype=torch.float8_e5m2,\n    offload_device=\"cpu\",\n    onload_dtype=torch.float8_e5m2,\n    onload_device=\"cpu\",\n    preparing_dtype=torch.float8_e5m2,\n    preparing_device=\"cuda\",\n    computation_dtype=torch.bfloat16,\n    computation_device=\"cuda\",\n)\nrepackaged_files = [\n    \"transformer_distilled.safetensors\",\n    \"text_encoder_post_modules.safetensors\",\n    \"video_vae_decoder.safetensors\",\n    \"audio_vae_decoder.safetensors\",\n    \"audio_vocoder.safetensors\",\n    \"video_vae_encoder.safetensors\",\n]\nmodel_configs = (\n    [ModelConfig(model_id=GEMMA_ID, origin_file_pattern=\"model-*.safetensors\", **vram_config)]\n    + [ModelConfig(model_id=REPACKAGE_ID, origin_file_pattern=file_name, **vram_config) for file_name in repackaged_files]\n    + [ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-spatial-upscaler-x2-1.0.safetensors\", **vram_config)]\n)\npipe = LTX2AudioVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=model_configs,\n    tokenizer_config=ModelConfig(model_id=GEMMA_ID),\n    # mem_get_info returns (free, total) bytes: the limit is total GPU memory in GiB minus 0.5 GiB.\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\n\nprompt = \"A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”\"\nnegative_prompt = (\n    \"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, \"\n    \"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, \"\n    \"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, \"\n    \"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of \"\n    \"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent \"\n    \"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny \"\n    \"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, \"\n    \"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, \"\n    \"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward \"\n    \"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, \"\n    \"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts.\"\n)\nheight, width = 512 * 2, 768 * 2\nnum_frames = 121\nfirst_frame_path = \"data/examples/ltx-2/first_frame.jpg\"\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/examples_in_diffsynth\",\n    local_dir=\"./\",\n    allow_file_pattern=[first_frame_path],\n)\nimage = Image.open(first_frame_path).convert(\"RGB\")\nimage = image.resize((width, height))\n# Condition frame index 0 (the first frame) on the image with input_images_strength=1.0.\ngeneration_kwargs = dict(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    seed=43,\n    height=height,\n    width=width,\n    num_frames=num_frames,\n    tiled=True,\n    use_distilled_pipeline=True,\n    input_images=[image],\n    input_images_indexes=[0],\n    input_images_strength=1.0,\n)\nvideo, audio = pipe(**generation_kwargs)\nwrite_video_audio_ltx2(\n    video=video, audio=audio,\n    output_path='ltx2_distilled_i2av_first.mp4',\n    fps=24, audio_sample_rate=24000,\n)\n"
  },
  {
    "path": "examples/ltx2/model_inference_low_vram/LTX-2-I2AV-OneStage.py",
    "content": "# LTX-2 one-stage image-to-audio-video example (low-VRAM settings); an input image conditions frame 0.\nimport torch\nfrom diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig\nfrom diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2\nfrom PIL import Image\nfrom modelscope import dataset_snapshot_download\n\nGEMMA_ID = \"google/gemma-3-12b-it-qat-q4_0-unquantized\"\nREPACKAGE_ID = \"DiffSynth-Studio/LTX-2-Repackage\"\n\n# offload/onload: float8_e5m2 on CPU; preparing: float8_e5m2 on CUDA; computation: bfloat16 on CUDA.\nvram_config = dict(\n    offload_dtype=torch.float8_e5m2,\n    offload_device=\"cpu\",\n    onload_dtype=torch.float8_e5m2,\n    onload_device=\"cpu\",\n    preparing_dtype=torch.float8_e5m2,\n    preparing_device=\"cuda\",\n    computation_dtype=torch.bfloat16,\n    computation_device=\"cuda\",\n)\nrepackaged_files = [\n    \"transformer.safetensors\",\n    \"text_encoder_post_modules.safetensors\",\n    \"video_vae_decoder.safetensors\",\n    \"audio_vae_decoder.safetensors\",\n    \"audio_vocoder.safetensors\",\n    \"video_vae_encoder.safetensors\",\n]\nmodel_configs = [ModelConfig(model_id=GEMMA_ID, origin_file_pattern=\"model-*.safetensors\", **vram_config)]\nmodel_configs.extend(ModelConfig(model_id=REPACKAGE_ID, origin_file_pattern=file_name, **vram_config) for file_name in repackaged_files)\npipe = LTX2AudioVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=model_configs,\n    tokenizer_config=ModelConfig(model_id=GEMMA_ID),\n    # mem_get_info returns (free, total) bytes: the limit is total GPU memory in GiB minus 0.5 GiB.\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\nprompt = \"A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”\"\nnegative_prompt = (\n    \"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, \"\n    \"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, \"\n    \"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, \"\n    \"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of \"\n    \"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent \"\n    \"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny \"\n    \"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, \"\n    \"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, \"\n    \"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward \"\n    \"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, \"\n    \"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts.\"\n)\nheight, width = 512, 768\nnum_frames = 121\nfirst_frame_path = \"data/examples/ltx-2/first_frame.jpg\"\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/examples_in_diffsynth\",\n    local_dir=\"./\",\n    allow_file_pattern=[first_frame_path],\n)\nimage = Image.open(first_frame_path).convert(\"RGB\")\nimage = image.resize((width, height))\n# Condition frame index 0 (the first frame) on the image with input_images_strength=1.0.\ngeneration_kwargs = dict(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    seed=43,\n    height=height,\n    width=width,\n    num_frames=num_frames,\n    tiled=False,\n    input_images=[image],\n    input_images_indexes=[0],\n    input_images_strength=1.0,\n    num_inference_steps=40,\n)\nvideo, audio = pipe(**generation_kwargs)\nwrite_video_audio_ltx2(\n    video=video, audio=audio,\n    output_path='ltx2_onestage_i2av_first.mp4',\n    fps=24, audio_sample_rate=24000,\n)\n"
  },
  {
    "path": "examples/ltx2/model_inference_low_vram/LTX-2-I2AV-TwoStage.py",
    "content": "# LTX-2 two-stage image-to-audio-video example (low-VRAM settings); an input image conditions frame 0.\nimport torch\nfrom diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig\nfrom diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2\nfrom PIL import Image\nfrom modelscope import dataset_snapshot_download\n\nGEMMA_ID = \"google/gemma-3-12b-it-qat-q4_0-unquantized\"\nREPACKAGE_ID = \"DiffSynth-Studio/LTX-2-Repackage\"\n\n# offload/onload: float8_e5m2 on CPU; preparing: float8_e5m2 on CUDA; computation: bfloat16 on CUDA.\nvram_config = dict(\n    offload_dtype=torch.float8_e5m2,\n    offload_device=\"cpu\",\n    onload_dtype=torch.float8_e5m2,\n    onload_device=\"cpu\",\n    preparing_dtype=torch.float8_e5m2,\n    preparing_device=\"cuda\",\n    computation_dtype=torch.bfloat16,\n    computation_device=\"cuda\",\n)\nrepackaged_files = [\n    \"transformer.safetensors\",\n    \"text_encoder_post_modules.safetensors\",\n    \"video_vae_decoder.safetensors\",\n    \"audio_vae_decoder.safetensors\",\n    \"audio_vocoder.safetensors\",\n    \"video_vae_encoder.safetensors\",\n]\nmodel_configs = [ModelConfig(model_id=GEMMA_ID, origin_file_pattern=\"model-*.safetensors\", **vram_config)]\nmodel_configs += [ModelConfig(model_id=REPACKAGE_ID, origin_file_pattern=file_name, **vram_config) for file_name in repackaged_files]\nmodel_configs.append(ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-spatial-upscaler-x2-1.0.safetensors\", **vram_config))\npipe = LTX2AudioVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=model_configs,\n    tokenizer_config=ModelConfig(model_id=GEMMA_ID),\n    # Distilled LoRA passed to the pipeline for its second stage.\n    stage2_lora_config=ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-19b-distilled-lora-384.safetensors\"),\n    # mem_get_info returns (free, total) bytes: the limit is total GPU memory in GiB minus 0.5 GiB.\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\n\nprompt = \"A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”\"\nnegative_prompt = (\n    \"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, \"\n    \"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, \"\n    \"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, \"\n    \"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of \"\n    \"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent \"\n    \"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny \"\n    \"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, \"\n    \"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, \"\n    \"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward \"\n    \"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, \"\n    \"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts.\"\n)\nheight, width = 512 * 2, 768 * 2\nnum_frames = 121\nfirst_frame_path = \"data/examples/ltx-2/first_frame.jpg\"\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/examples_in_diffsynth\",\n    local_dir=\"./\",\n    allow_file_pattern=[first_frame_path],\n)\nimage = Image.open(first_frame_path).convert(\"RGB\")\nimage = image.resize((width, height))\n# Condition frame index 0 (the first frame) on the image with input_images_strength=1.0.\ngeneration_kwargs = dict(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    seed=42,\n    height=height,\n    width=width,\n    num_frames=num_frames,\n    tiled=True,\n    use_two_stage_pipeline=True,\n    num_inference_steps=40,\n    input_images=[image],\n    input_images_indexes=[0],\n    input_images_strength=1.0,\n)\nvideo, audio = pipe(**generation_kwargs)\nwrite_video_audio_ltx2(\n    video=video, audio=audio,\n    output_path='ltx2_twostage_i2av_first.mp4',\n    fps=24, audio_sample_rate=24000,\n)\n"
  },
  {
    "path": "examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-In.py",
    "content": "# LTX-2 two-stage text-to-audio-video example with the Dolly-In camera-control LoRA (low-VRAM settings).\nimport torch\nfrom diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig\nfrom diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2\n\nGEMMA_ID = \"google/gemma-3-12b-it-qat-q4_0-unquantized\"\nREPACKAGE_ID = \"DiffSynth-Studio/LTX-2-Repackage\"\n\n# offload/onload: float8_e5m2 on CPU; preparing: float8_e5m2 on CUDA; computation: bfloat16 on CUDA.\nvram_config = dict(\n    offload_dtype=torch.float8_e5m2,\n    offload_device=\"cpu\",\n    onload_dtype=torch.float8_e5m2,\n    onload_device=\"cpu\",\n    preparing_dtype=torch.float8_e5m2,\n    preparing_device=\"cuda\",\n    computation_dtype=torch.bfloat16,\n    computation_device=\"cuda\",\n)\nrepackaged_files = [\n    \"transformer.safetensors\",\n    \"text_encoder_post_modules.safetensors\",\n    \"video_vae_decoder.safetensors\",\n    \"audio_vae_decoder.safetensors\",\n    \"audio_vocoder.safetensors\",\n    \"video_vae_encoder.safetensors\",\n]\nmodel_configs = [ModelConfig(model_id=GEMMA_ID, origin_file_pattern=\"model-*.safetensors\", **vram_config)]\nmodel_configs += [ModelConfig(model_id=REPACKAGE_ID, origin_file_pattern=file_name, **vram_config) for file_name in repackaged_files]\nmodel_configs.append(ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-spatial-upscaler-x2-1.0.safetensors\", **vram_config))\npipe = LTX2AudioVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=model_configs,\n    tokenizer_config=ModelConfig(model_id=GEMMA_ID),\n    # Distilled LoRA passed to the pipeline for its second stage.\n    stage2_lora_config=ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-19b-distilled-lora-384.safetensors\"),\n)\n# Attach the camera-control LoRA to the DiT.\ncamera_lora = ModelConfig(model_id=\"Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-In\", origin_file_pattern=\"ltx-2-19b-lora-camera-control-dolly-in.safetensors\")\npipe.load_lora(pipe.dit, camera_lora)\n\nprompt = \"Dolly-in shot: A cheerful girl smiles brightly and says, 'I enjoy working with Diffsynth-Studio, it's a perfect framework.' The camera smoothly moves closer to her face, highlighting her enthusiasm and sincerity.\"\nnegative_prompt = (\n    \"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, \"\n    \"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, \"\n    \"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, \"\n    \"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of \"\n    \"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent \"\n    \"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny \"\n    \"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, \"\n    \"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, \"\n    \"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward \"\n    \"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, \"\n    \"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts.\"\n)\nheight, width = 512 * 2, 768 * 2\nnum_frames = 121\ngeneration_kwargs = dict(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    seed=43,\n    height=height,\n    width=width,\n    num_frames=num_frames,\n    tiled=True,\n    use_two_stage_pipeline=True,\n)\nvideo, audio = pipe(**generation_kwargs)\nwrite_video_audio_ltx2(\n    video=video, audio=audio,\n    output_path='ltx2_dolly_in_lora.mp4',\n    fps=24, audio_sample_rate=24000,\n)\n"
  },
  {
    "path": "examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Left.py",
    "content": "# LTX-2 two-stage text-to-audio-video example with the Dolly-Left camera-control LoRA (low-VRAM settings).\nimport torch\nfrom diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig\nfrom diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2\n\nGEMMA_ID = \"google/gemma-3-12b-it-qat-q4_0-unquantized\"\nREPACKAGE_ID = \"DiffSynth-Studio/LTX-2-Repackage\"\n\n# offload/onload: float8_e5m2 on CPU; preparing: float8_e5m2 on CUDA; computation: bfloat16 on CUDA.\nvram_config = dict(\n    offload_dtype=torch.float8_e5m2,\n    offload_device=\"cpu\",\n    onload_dtype=torch.float8_e5m2,\n    onload_device=\"cpu\",\n    preparing_dtype=torch.float8_e5m2,\n    preparing_device=\"cuda\",\n    computation_dtype=torch.bfloat16,\n    computation_device=\"cuda\",\n)\nrepackaged_files = [\n    \"transformer.safetensors\",\n    \"text_encoder_post_modules.safetensors\",\n    \"video_vae_decoder.safetensors\",\n    \"audio_vae_decoder.safetensors\",\n    \"audio_vocoder.safetensors\",\n    \"video_vae_encoder.safetensors\",\n]\nmodel_configs = [ModelConfig(model_id=GEMMA_ID, origin_file_pattern=\"model-*.safetensors\", **vram_config)]\nmodel_configs += [ModelConfig(model_id=REPACKAGE_ID, origin_file_pattern=file_name, **vram_config) for file_name in repackaged_files]\nmodel_configs.append(ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-spatial-upscaler-x2-1.0.safetensors\", **vram_config))\npipe = LTX2AudioVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=model_configs,\n    tokenizer_config=ModelConfig(model_id=GEMMA_ID),\n    # Distilled LoRA passed to the pipeline for its second stage.\n    stage2_lora_config=ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-19b-distilled-lora-384.safetensors\"),\n)\n# Attach the camera-control LoRA to the DiT.\ncamera_lora = ModelConfig(model_id=\"Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Left\", origin_file_pattern=\"ltx-2-19b-lora-camera-control-dolly-left.safetensors\")\npipe.load_lora(pipe.dit, camera_lora)\n\nprompt = \"Dolly-left shot: A joyful young woman sits at a minimalist desk with a laptop running Diffsynth-Studio, code and generative visuals glowing on screen. She turns slightly toward the camera and says with a smile, 'I enjoy working with Diffsynth-Studio, it's a perfect framework.' As she speaks, the camera smoothly dollies left, revealing a wall of framed open-source project posters, a whiteboard covered in neural network sketches, and a shelf stacked with AI/graphics books beside her.\"\nnegative_prompt = (\n    \"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, \"\n    \"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, \"\n    \"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, \"\n    \"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of \"\n    \"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent \"\n    \"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny \"\n    \"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, \"\n    \"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, \"\n    \"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward \"\n    \"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, \"\n    \"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts.\"\n)\nheight, width = 512 * 2, 768 * 2\nnum_frames = 121\ngeneration_kwargs = dict(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    seed=43,\n    height=height,\n    width=width,\n    num_frames=num_frames,\n    tiled=True,\n    use_two_stage_pipeline=True,\n)\nvideo, audio = pipe(**generation_kwargs)\nwrite_video_audio_ltx2(\n    video=video, audio=audio,\n    output_path='ltx2_dolly_left_lora.mp4',\n    fps=24, audio_sample_rate=24000,\n)\n"
  },
  {
    "path": "examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Out.py",
    "content": "# LTX-2 two-stage text-to-audio-video example with the Dolly-Out camera-control LoRA (low-VRAM settings).\nimport torch\nfrom diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig\nfrom diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2\n\n# offload/onload: float8_e5m2 on CPU; preparing: float8_e5m2 on CUDA; computation: bfloat16 on CUDA.\nvram_config = {\n    \"offload_dtype\": torch.float8_e5m2,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.float8_e5m2,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e5m2,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = LTX2AudioVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\", origin_file_pattern=\"model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"transformer.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"text_encoder_post_modules.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"video_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"audio_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"audio_vocoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"video_vae_encoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-spatial-upscaler-x2-1.0.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\"),\n    stage2_lora_config=ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-19b-distilled-lora-384.safetensors\"),\n)\n# Attach the Dolly-Out camera-control LoRA to the DiT.\npipe.load_lora(\n    pipe.dit,\n    ModelConfig(model_id=\"Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Out\", origin_file_pattern=\"ltx-2-19b-lora-camera-control-dolly-out.safetensors\"),\n)\n\nprompt = \"Dolly-out shot: A joyful young woman smiles warmly and says: 'I enjoy working with Diffsynth-Studio, it's a perfect framework.' As she speaks, the camera slowly dollies out, revealing a bright, modern creative studio filled with plants, whiteboards full of diagrams, and soft natural light from large windows.\"\nnegative_prompt = (\n    \"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, \"\n    \"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, \"\n    \"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, \"\n    \"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of \"\n    \"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent \"\n    \"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny \"\n    \"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, \"\n    \"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, \"\n    \"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward \"\n    \"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, \"\n    \"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts.\"\n)\nheight, width, num_frames = 512 * 2, 768 * 2, 121\nvideo, audio = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    seed=43,\n    height=height,\n    width=width,\n    num_frames=num_frames,\n    tiled=True,\n    use_two_stage_pipeline=True,\n)\nwrite_video_audio_ltx2(\n    video=video,\n    audio=audio,\n    output_path='ltx2_dolly_out.mp4',\n    fps=24,\n    audio_sample_rate=24000,\n)\n"
  },
  {
    "path": "examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Right.py",
    "content": "import torch\nfrom diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig\nfrom diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2\n\nvram_config = {\n    \"offload_dtype\": torch.float8_e5m2,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.float8_e5m2,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e5m2,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = LTX2AudioVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\", origin_file_pattern=\"model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"transformer.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"text_encoder_post_modules.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"video_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"audio_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"audio_vocoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"video_vae_encoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-spatial-upscaler-x2-1.0.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\"),\n    stage2_lora_config=ModelConfig(model_id=\"Lightricks/LTX-2\", 
origin_file_pattern=\"ltx-2-19b-distilled-lora-384.safetensors\"),\n)\npipe.load_lora(\n    pipe.dit,\n    ModelConfig(model_id=\"Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Right\", origin_file_pattern=\"ltx-2-19b-lora-camera-control-dolly-right.safetensors\"),\n)\n\nprompt = \"Dolly-right shot: A happy girl looks up and says happily, 'I enjoy working with Diffsynth-Studio, it's a perfect framework.' She sits before a sunlit café table, her open laptop displaying the Github interface. The camera glides right to show a barista crafting coffee in the background, shelves of artisan beans, and a chalkboard menu softly blurred in the bokeh.\"\nnegative_prompt = (\n    \"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, \"\n    \"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, \"\n    \"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, \"\n    \"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of \"\n    \"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent \"\n    \"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny \"\n    \"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, \"\n    \"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, \"\n    \"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward \"\n    \"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, \"\n    \"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts.\"\n)\nheight, width, num_frames = 512 * 2, 768 * 2, 121\nvideo, audio = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    
seed=43,\n    height=height,\n    width=width,\n    num_frames=num_frames,\n    tiled=True,\n    use_two_stage_pipeline=True,\n)\nwrite_video_audio_ltx2(\n    video=video,\n    audio=audio,\n    output_path='ltx2_dolly_right_lora.mp4',\n    fps=24,\n    audio_sample_rate=24000,\n)\n"
  },
  {
    "path": "examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Jib-Down.py",
    "content": "import torch\nfrom diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig\nfrom diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2\n\nvram_config = {\n    \"offload_dtype\": torch.float8_e5m2,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.float8_e5m2,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e5m2,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = LTX2AudioVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\", origin_file_pattern=\"model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"transformer.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"text_encoder_post_modules.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"video_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"audio_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"audio_vocoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"video_vae_encoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-spatial-upscaler-x2-1.0.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\"),\n    stage2_lora_config=ModelConfig(model_id=\"Lightricks/LTX-2\", 
origin_file_pattern=\"ltx-2-19b-distilled-lora-384.safetensors\"),\n)\npipe.load_lora(\n    pipe.dit,\n    ModelConfig(model_id=\"Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Down\", origin_file_pattern=\"ltx-2-19b-lora-camera-control-jib-down.safetensors\"),\n)\nprompt = (\n    \"A girl is very happy, standing on a clean studio floor with soft ambient lighting. \"\n    \"She is speaking directly to the camera: “I enjoy working with Diffsynth-Studio, it's a perfect framework.” \"\n    \"The shot begins with a medium close-up framing her from the waist up. As the camera performs a smooth jib-down movement—\"\n    \"descending vertically downward—it gradually reveals more of the lower portion of the scene. \"\n    \"During the descent, the following elements become visible near the bottom of the frame: \"\n    \"- The polished concrete floor with subtle reflections of the girl’s shoes, \"\n    \"- A small branded mat labeled “Diffsynth-Studio” placed just beneath her feet, \"\n    \"- The lower part of a sleek workstation desk with a glowing logo on its front panel, partially hidden at the start but fully revealed as the camera lowers. 
\"\n    \"This downward motion provides a dynamic reveal of contextual details that reinforce the professional and creative environment, \"\n    \"while maintaining focus on the girl’s enthusiastic expression throughout.\"\n)\nnegative_prompt = (\n    \"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, \"\n    \"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, \"\n    \"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, \"\n    \"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of \"\n    \"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent \"\n    \"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny \"\n    \"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, \"\n    \"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, \"\n    \"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward \"\n    \"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, \"\n    \"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts.\"\n)\nheight, width, num_frames = 512 * 2, 768 * 2, 121\nvideo, audio = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    seed=43,\n    height=height,\n    width=width,\n    num_frames=num_frames,\n    tiled=True,\n    use_two_stage_pipeline=True,\n)\nwrite_video_audio_ltx2(\n    video=video,\n    audio=audio,\n    output_path='ltx2_camera_jib_down.mp4',\n    fps=24,\n    audio_sample_rate=24000,\n)\n"
  },
  {
    "path": "examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Jib-Up.py",
    "content": "import torch\nfrom diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig\nfrom diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2\n\nvram_config = {\n    \"offload_dtype\": torch.float8_e5m2,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.float8_e5m2,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e5m2,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = LTX2AudioVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\", origin_file_pattern=\"model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"transformer.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"text_encoder_post_modules.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"video_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"audio_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"audio_vocoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"video_vae_encoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-spatial-upscaler-x2-1.0.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\"),\n    stage2_lora_config=ModelConfig(model_id=\"Lightricks/LTX-2\", 
origin_file_pattern=\"ltx-2-19b-distilled-lora-384.safetensors\"),\n)\npipe.load_lora(\n    pipe.dit,\n    ModelConfig(model_id=\"Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Up\", origin_file_pattern=\"ltx-2-19b-lora-camera-control-jib-up.safetensors\"),\n)\nprompt = (\n    \"A girl stands happily at a sleek desk with a glowing 'Diffsynth-Studio' logo, saying: “I enjoy working with Diffsynth-Studio, it's a perfect framework.” \"\n    \"The shot starts low—framing her waist, shoes, and a branded floor mat—and smoothly jib-ups upward. \"\n    \"As the camera rises, it reveals her smiling face, upper body, and behind her: a bright creative studio with wall art and a large window showing daylight sky. \"\n    \"The final frame fully shows the inspiring workspace above the initial view, ensuring spatial continuity.\"\n)\nnegative_prompt = (\n    \"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, \"\n    \"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, \"\n    \"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, \"\n    \"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of \"\n    \"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent \"\n    \"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny \"\n    \"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, \"\n    \"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, \"\n    \"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward \"\n    \"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, \"\n    \"inconsistent tone, cinematic oversaturation, 
stylized filters, or AI artifacts.\"\n)\nheight, width, num_frames = 512 * 2, 768 * 2, 121\nvideo, audio = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    seed=43,\n    height=height,\n    width=width,\n    num_frames=num_frames,\n    tiled=True,\n    use_two_stage_pipeline=True,\n)\nwrite_video_audio_ltx2(\n    video=video,\n    audio=audio,\n    output_path='ltx2_camera_jib_up.mp4',\n    fps=24,\n    audio_sample_rate=24000,\n)\n"
  },
  {
    "path": "examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Static.py",
    "content": "import torch\nfrom diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig\nfrom diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2\n\nvram_config = {\n    \"offload_dtype\": torch.float8_e5m2,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.float8_e5m2,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e5m2,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = LTX2AudioVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\", origin_file_pattern=\"model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"transformer.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"text_encoder_post_modules.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"video_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"audio_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"audio_vocoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"video_vae_encoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-spatial-upscaler-x2-1.0.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\"),\n    stage2_lora_config=ModelConfig(model_id=\"Lightricks/LTX-2\", 
origin_file_pattern=\"ltx-2-19b-distilled-lora-384.safetensors\"),\n)\npipe.load_lora(\n    pipe.dit,\n    ModelConfig(model_id=\"Lightricks/LTX-2-19b-LoRA-Camera-Control-Static\", origin_file_pattern=\"ltx-2-19b-lora-camera-control-static.safetensors\"),\n)\nprompt = \"A beautiful sunset over the ocean.\"\nnegative_prompt = (\n    \"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, \"\n    \"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, \"\n    \"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, \"\n    \"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of \"\n    \"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent \"\n    \"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny \"\n    \"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, \"\n    \"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, \"\n    \"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward \"\n    \"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, \"\n    \"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts.\"\n)\nheight, width, num_frames = 512 * 2, 768 * 2, 121\nvideo, audio = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    seed=43,\n    height=height,\n    width=width,\n    num_frames=num_frames,\n    tiled=True,\n    use_two_stage_pipeline=True,\n)\nwrite_video_audio_ltx2(\n    video=video,\n    audio=audio,\n    output_path='ltx2_camera_static.mp4',\n    fps=24,\n    audio_sample_rate=24000,\n)\n"
  },
  {
    "path": "examples/ltx2/model_inference_low_vram/LTX-2-T2AV-DistilledPipeline.py",
    "content": "import torch\nfrom diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig\nfrom diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2\n\nvram_config = {\n    \"offload_dtype\": torch.float8_e5m2,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.float8_e5m2,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e5m2,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = LTX2AudioVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\", origin_file_pattern=\"model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"transformer_distilled.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"text_encoder_post_modules.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"video_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"audio_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"audio_vocoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"video_vae_encoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-spatial-upscaler-x2-1.0.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\n\nprompt = \"A girl is very happy, she is speaking: “I 
enjoy working with Diffsynth-Studio, it's a perfect framework.”\"\nnegative_prompt = (\n    \"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, \"\n    \"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, \"\n    \"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, \"\n    \"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of \"\n    \"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent \"\n    \"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny \"\n    \"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, \"\n    \"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, \"\n    \"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward \"\n    \"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, \"\n    \"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts.\"\n)\nheight, width, num_frames = 512 * 2, 768 * 2, 121\nvideo, audio = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    seed=43,\n    height=height,\n    width=width,\n    num_frames=num_frames,\n    tiled=True,\n    use_distilled_pipeline=True,\n)\nwrite_video_audio_ltx2(\n    video=video,\n    audio=audio,\n    output_path='ltx2_distilled.mp4',\n    fps=24,\n    audio_sample_rate=24000,\n)\n"
  },
  {
    "path": "examples/ltx2/model_inference_low_vram/LTX-2-T2AV-IC-LoRA-Detailer.py",
    "content": "import torch\nfrom diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig\nfrom diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2\nfrom diffsynth.utils.data import VideoData\nfrom modelscope import dataset_snapshot_download\n\nvram_config = {\n    \"offload_dtype\": torch.float8_e5m2,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.float8_e5m2,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e5m2,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = LTX2AudioVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\", origin_file_pattern=\"model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"transformer.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"text_encoder_post_modules.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"video_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"audio_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"audio_vocoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"video_vae_encoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-spatial-upscaler-x2-1.0.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\"),\n    
stage2_lora_config=ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-19b-distilled-lora-384.safetensors\"),\n)\npipe.load_lora(pipe.dit, ModelConfig(model_id=\"Lightricks/LTX-2-19b-IC-LoRA-Detailer\", origin_file_pattern=\"ltx-2-19b-ic-lora-detailer.safetensors\"))\ndataset_snapshot_download(\"DiffSynth-Studio/example_video_dataset\", allow_file_pattern=\"ltx2/*\", local_dir=\"data/example_video_dataset\")\n\nprompt = \"[VISUAL]:Two cute orange cats, wearing boxing gloves, stand on a boxing ring and fight each other. [SOUNDS]:the sound of two cats boxing\"\nnegative_prompt = (\n    \"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, \"\n    \"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, \"\n    \"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, \"\n    \"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of \"\n    \"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent \"\n    \"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny \"\n    \"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, \"\n    \"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, \"\n    \"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward \"\n    \"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, \"\n    \"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts.\"\n)\nheight, width, num_frames = 512 * 2, 768 * 2, 121\nref_scale_factor = 1\nframe_rate = 24\n# the frame rate of the video should better be the same with the reference video\n# the spatial resolution of 
the first frame should be the resolution of stage 1 video generation divided by ref_scale_factor\ninput_video = VideoData(\"data/example_video_dataset/ltx2/video1.mp4\", height=height // ref_scale_factor // 2, width=width // ref_scale_factor // 2)\ninput_video = input_video.raw_data()\nvideo, audio = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    seed=43,\n    height=height,\n    width=width,\n    num_frames=num_frames,\n    frame_rate=frame_rate,\n    in_context_videos=[input_video],\n    in_context_downsample_factor=ref_scale_factor,\n    tiled=True,\n    use_two_stage_pipeline=True,\n    clear_lora_before_state_two=True,\n)\nwrite_video_audio_ltx2(\n    video=video,\n    audio=audio,\n    output_path='ltx2_twostage_iclora.mp4',\n    fps=frame_rate,\n    audio_sample_rate=24000,\n)\n"
  },
  {
    "path": "examples/ltx2/model_inference_low_vram/LTX-2-T2AV-IC-LoRA-Union-Control.py",
    "content": "import torch\nfrom diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig\nfrom diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2\nfrom diffsynth.utils.data import VideoData\nfrom modelscope import dataset_snapshot_download\n\nvram_config = {\n    \"offload_dtype\": torch.float8_e5m2,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.float8_e5m2,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e5m2,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = LTX2AudioVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\", origin_file_pattern=\"model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"transformer.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"text_encoder_post_modules.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"video_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"audio_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"audio_vocoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"video_vae_encoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-spatial-upscaler-x2-1.0.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\"),\n    
stage2_lora_config=ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-19b-distilled-lora-384.safetensors\"),\n)\npipe.load_lora(pipe.dit, ModelConfig(model_id=\"Lightricks/LTX-2-19b-IC-LoRA-Union-Control\", origin_file_pattern=\"ltx-2-19b-ic-lora-union-control-ref0.5.safetensors\"))\ndataset_snapshot_download(\"DiffSynth-Studio/example_video_dataset\", allow_file_pattern=\"ltx2/*\", local_dir=\"data/example_video_dataset\")\n\nprompt = \"[VISUAL]:Two cute orange cats, wearing boxing gloves, stand on a boxing ring and fight each other. [SOUNDS]:the sound of two cats boxing\"\nnegative_prompt = (\n    \"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, \"\n    \"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, \"\n    \"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, \"\n    \"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of \"\n    \"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent \"\n    \"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny \"\n    \"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, \"\n    \"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, \"\n    \"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward \"\n    \"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, \"\n    \"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts.\"\n)\nheight, width, num_frames = 512 * 2, 768 * 2, 121\nref_scale_factor = 2\nframe_rate = 24\n# the frame rate of the video should better be the same with the reference video\n# the spatial 
resolution of the first frame should be the resolution of stage 1 video generation divided by ref_scale_factor\ninput_video = VideoData(\"data/example_video_dataset/ltx2/depth_video.mp4\", height=height // ref_scale_factor // 2, width=width // ref_scale_factor // 2)\ninput_video = input_video.raw_data()\nvideo, audio = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    seed=43,\n    height=height,\n    width=width,\n    num_frames=num_frames,\n    frame_rate=frame_rate,\n    in_context_videos=[input_video],\n    in_context_downsample_factor=ref_scale_factor,\n    tiled=True,\n    use_two_stage_pipeline=True,\n    clear_lora_before_state_two=True,\n)\nwrite_video_audio_ltx2(\n    video=video,\n    audio=audio,\n    output_path='ltx2_twostage_iclora.mp4',\n    fps=frame_rate,\n    audio_sample_rate=24000,\n)\n"
  },
  {
    "path": "examples/ltx2/model_inference_low_vram/LTX-2-T2AV-OneStage.py",
    "content": "import torch\nfrom diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig\nfrom diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2\n\nvram_config = {\n    \"offload_dtype\": torch.float8_e5m2,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.float8_e5m2,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e5m2,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\n\"\"\"\nOffical model repo: https://www.modelscope.cn/models/Lightricks/LTX-2\nRepackaged model repo: https://www.modelscope.cn/models/DiffSynth-Studio/LTX-2-Repackage\nFor base models of LTX-2, offical checkpoint (with model config ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-19b-dev.safetensors\"))\nand repackaged checkpoints (with model config ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"*.safetensors\")) are both supported.\nWe have repackeged the official checkpoints in DiffSynth-Studio/LTX-2-Repackage repo to support separate loading of different submodules,\nand avoid redundant memory usage when users only want to use part of the model.\n\"\"\"\n# use the repackaged modelconfig from \"DiffSynth-Studio/LTX-2-Repackage\" to avoid redundant model loading\npipe = LTX2AudioVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\", origin_file_pattern=\"model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"transformer.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"text_encoder_post_modules.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", 
origin_file_pattern=\"video_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"audio_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"audio_vocoder.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\n# use the following modelconfig if you want to initialize model from offical checkpoints from \"Lightricks/LTX-2\"\n# pipe = LTX2AudioVideoPipeline.from_pretrained(\n#     torch_dtype=torch.bfloat16,\n#     device=\"cuda\",\n#     model_configs=[\n#         ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\", origin_file_pattern=\"model-*.safetensors\", **vram_config),\n#         ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-19b-dev.safetensors\", **vram_config),\n#     ],\n#     tokenizer_config=ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\"),\n#     vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n# )\nprompt = \"A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”\"\nnegative_prompt = \"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, 
wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts.\"\nheight, width, num_frames = 512, 768, 121\nvideo, audio = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    seed=43,\n    height=height,\n    width=width,\n    num_frames=num_frames,\n    tiled=True,\n)\nwrite_video_audio_ltx2(\n    video=video,\n    audio=audio,\n    output_path='ltx2_onestage.mp4',\n    fps=24,\n    audio_sample_rate=24000,\n)\n"
  },
  {
    "path": "examples/ltx2/model_inference_low_vram/LTX-2-T2AV-TwoStage.py",
    "content": "import torch\nfrom diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig\nfrom diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2\n\nvram_config = {\n    \"offload_dtype\": torch.float8_e5m2,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.float8_e5m2,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e5m2,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\n\"\"\"\nOfficial model repo: https://www.modelscope.cn/models/Lightricks/LTX-2\nRepackaged model repo: https://www.modelscope.cn/models/DiffSynth-Studio/LTX-2-Repackage\nFor base models of LTX-2, official checkpoint (with model config ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-19b-dev.safetensors\"))\nand repackaged checkpoints (with model config ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"*.safetensors\")) are both supported.\nWe have repackaged the official checkpoints in DiffSynth-Studio/LTX-2-Repackage repo to support separate loading of different submodules,\nand avoid redundant memory usage when users only want to use part of the model.\n\"\"\"\n# use the repackaged modelconfig from \"DiffSynth-Studio/LTX-2-Repackage\" to avoid redundant model loading\npipe = LTX2AudioVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\", origin_file_pattern=\"model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"transformer.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"text_encoder_post_modules.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", 
origin_file_pattern=\"video_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"audio_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"audio_vocoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"video_vae_encoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-spatial-upscaler-x2-1.0.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\"),\n    stage2_lora_config=ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-19b-distilled-lora-384.safetensors\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\n# use the following modelconfig if you want to initialize model from offical checkpoints from \"Lightricks/LTX-2\"\n# pipe = LTX2AudioVideoPipeline.from_pretrained(\n#     torch_dtype=torch.bfloat16,\n#     device=\"cuda\",\n#     model_configs=[\n#         ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\", origin_file_pattern=\"model-*.safetensors\", **vram_config),\n#         ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-19b-dev.safetensors\", **vram_config),\n#         ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-spatial-upscaler-x2-1.0.safetensors\", **vram_config),\n#     ],\n#     tokenizer_config=ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\"),\n#     stage2_lora_config=ModelConfig(model_id=\"Lightricks/LTX-2\", origin_file_pattern=\"ltx-2-19b-distilled-lora-384.safetensors\"),\n#     vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n# )\n\nprompt = \"A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a 
perfect framework.”\"\nnegative_prompt = (\n    \"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, \"\n    \"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, \"\n    \"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, \"\n    \"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of \"\n    \"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent \"\n    \"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny \"\n    \"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, \"\n    \"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, \"\n    \"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward \"\n    \"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, \"\n    \"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts.\"\n)\nheight, width, num_frames = 512 * 2, 768 * 2, 121\nvideo, audio = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    seed=43,\n    height=height,\n    width=width,\n    num_frames=num_frames,\n    tiled=True,\n    use_two_stage_pipeline=True,\n)\nwrite_video_audio_ltx2(\n    video=video,\n    audio=audio,\n    output_path='ltx2_twostage.mp4',\n    fps=24,\n    audio_sample_rate=24000,\n)\n"
  },
  {
    "path": "examples/ltx2/model_inference_low_vram/LTX-2.3-A2V-TwoStage.py",
    "content": "import torch\nfrom diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig\nfrom diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2\nfrom diffsynth.utils.data.audio import read_audio\nfrom modelscope import dataset_snapshot_download\n\nvram_config = {\n    \"offload_dtype\": torch.float8_e5m2,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.float8_e5m2,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e5m2,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = LTX2AudioVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\", origin_file_pattern=\"model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Lightricks/LTX-2.3\", origin_file_pattern=\"ltx-2.3-22b-dev.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Lightricks/LTX-2.3\", origin_file_pattern=\"ltx-2.3-spatial-upscaler-x2-1.0.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\"),\n    stage2_lora_config=ModelConfig(model_id=\"Lightricks/LTX-2.3\", origin_file_pattern=\"ltx-2.3-22b-distilled-lora-384.safetensors\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\n\ndataset_snapshot_download(\"DiffSynth-Studio/example_video_dataset\", allow_file_pattern=\"ltx2/*\", local_dir=\"data/example_video_dataset\")\nprompt = \"A beautiful woman with a flower crown is singing happily under a blooming cherry tree.\"\nnegative_prompt = (\n    \"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, \"\n    \"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, \"\n    \"deformed facial features, asymmetrical 
face, missing facial features, extra limbs, disfigured hands, \"\n    \"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of \"\n    \"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent \"\n    \"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny \"\n    \"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, \"\n    \"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, \"\n    \"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward \"\n    \"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, \"\n    \"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts.\"\n)\nheight, width, num_frames, frame_rate = 512 * 2, 768 * 2, 121, 24\nduration = num_frames / frame_rate\naudio, audio_sample_rate = read_audio(\"data/example_video_dataset/ltx2/sing.MP3\", start_time=1, duration=duration)\nvideo, audio = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    retake_audio=audio,\n    audio_sample_rate=audio_sample_rate,\n    seed=43,\n    height=height,\n    width=width,\n    num_frames=num_frames,\n    frame_rate=frame_rate,\n    tiled=True,\n    use_two_stage_pipeline=True,\n)\nwrite_video_audio_ltx2(\n    video=video,\n    audio=audio,\n    output_path='ltx2.3_twostage_a2v.mp4',\n    fps=frame_rate,\n    audio_sample_rate=pipe.audio_vocoder.output_sampling_rate,\n)\n"
  },
  {
    "path": "examples/ltx2/model_inference_low_vram/LTX-2.3-I2AV-DistilledPipeline.py",
    "content": "import torch\nfrom diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig\nfrom diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2\nfrom PIL import Image\nfrom modelscope import dataset_snapshot_download\n\nvram_config = {\n    \"offload_dtype\": torch.float8_e5m2,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.float8_e5m2,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e5m2,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = LTX2AudioVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\", origin_file_pattern=\"model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Lightricks/LTX-2.3\", origin_file_pattern=\"ltx-2.3-22b-distilled.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Lightricks/LTX-2.3\", origin_file_pattern=\"ltx-2.3-spatial-upscaler-x2-1.0.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\n\n\nprompt = \"A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”\"\nnegative_prompt = (\n    \"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, \"\n    \"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, \"\n    \"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, \"\n    \"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of \"\n    \"field, background too sharp, background clutter, distracting reflections, harsh shadows, 
inconsistent \"\n    \"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny \"\n    \"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, \"\n    \"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, \"\n    \"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward \"\n    \"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, \"\n    \"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts.\"\n)\nheight, width, num_frames = 512 * 2, 768 * 2, 121\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/examples_in_diffsynth\",\n    local_dir=\"./\",\n    allow_file_pattern=[\"data/examples/ltx-2/first_frame.jpg\"]\n)\nimage = Image.open(\"data/examples/ltx-2/first_frame.jpg\").convert(\"RGB\").resize((width, height))\n# first frame\nvideo, audio = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    seed=43,\n    height=height,\n    width=width,\n    num_frames=num_frames,\n    tiled=True,\n    use_distilled_pipeline=True,\n    input_images=[image],\n    input_images_indexes=[0],\n    input_images_strength=1.0,\n)\nwrite_video_audio_ltx2(\n    video=video,\n    audio=audio,\n    output_path='ltx2.3_distilled_i2av_first.mp4',\n    fps=24,\n    audio_sample_rate=pipe.audio_vocoder.output_sampling_rate,\n)\n"
  },
  {
    "path": "examples/ltx2/model_inference_low_vram/LTX-2.3-I2AV-OneStage.py",
    "content": "import torch\nfrom diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig\nfrom diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2\nfrom PIL import Image\nfrom modelscope import dataset_snapshot_download\n\nvram_config = {\n    \"offload_dtype\": torch.float8_e5m2,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.float8_e5m2,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e5m2,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = LTX2AudioVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\", origin_file_pattern=\"model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Lightricks/LTX-2.3\", origin_file_pattern=\"ltx-2.3-22b-dev.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\n\nprompt = \"A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”\"\nnegative_prompt = \"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze 
direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts.\"\nheight, width, num_frames = 512, 768, 121\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/examples_in_diffsynth\",\n    local_dir=\"./\",\n    allow_file_pattern=[\"data/examples/ltx-2/first_frame.jpg\"]\n)\nimage = Image.open(\"data/examples/ltx-2/first_frame.jpg\").convert(\"RGB\").resize((width, height))\n# first frame\nvideo, audio = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    seed=43,\n    height=height,\n    width=width,\n    num_frames=num_frames,\n    tiled=False,\n    input_images=[image],\n    input_images_indexes=[0],\n    input_images_strength=1.0,\n)\nwrite_video_audio_ltx2(\n    video=video,\n    audio=audio,\n    output_path='ltx2.3_onestage_i2av_first.mp4',\n    fps=24,\n    audio_sample_rate=pipe.audio_vocoder.output_sampling_rate,\n)\n"
  },
  {
    "path": "examples/ltx2/model_inference_low_vram/LTX-2.3-I2AV-TwoStage.py",
    "content": "import torch\nfrom diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig\nfrom diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2\nfrom PIL import Image\nfrom modelscope import dataset_snapshot_download\n\nvram_config = {\n    \"offload_dtype\": torch.float8_e5m2,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.float8_e5m2,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e5m2,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = LTX2AudioVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\", origin_file_pattern=\"model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Lightricks/LTX-2.3\", origin_file_pattern=\"ltx-2.3-22b-dev.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Lightricks/LTX-2.3\", origin_file_pattern=\"ltx-2.3-spatial-upscaler-x2-1.0.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\"),\n    stage2_lora_config=ModelConfig(model_id=\"Lightricks/LTX-2.3\", origin_file_pattern=\"ltx-2.3-22b-distilled-lora-384.safetensors\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\n\nprompt = \"A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”\"\nnegative_prompt = (\n    \"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, \"\n    \"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, \"\n    \"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, \"\n    \"wrong hand count, artifacts around text, inconsistent perspective, camera shake, 
incorrect depth of \"\n    \"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent \"\n    \"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny \"\n    \"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, \"\n    \"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, \"\n    \"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward \"\n    \"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, \"\n    \"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts.\"\n)\nheight, width, num_frames = 512 * 2, 768 * 2, 121\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/examples_in_diffsynth\",\n    local_dir=\"./\",\n    allow_file_pattern=[\"data/examples/ltx-2/first_frame.jpg\"]\n)\nimage = Image.open(\"data/examples/ltx-2/first_frame.jpg\").convert(\"RGB\").resize((width, height))\n# first frame\nvideo, audio = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    seed=42,\n    height=height,\n    width=width,\n    num_frames=num_frames,\n    tiled=True,\n    use_two_stage_pipeline=True,\n    input_images=[image],\n    input_images_indexes=[0],\n    input_images_strength=1.0,\n)\nwrite_video_audio_ltx2(\n    video=video,\n    audio=audio,\n    output_path='ltx2.3_twostage_i2av_first.mp4',\n    fps=24,\n    audio_sample_rate=pipe.audio_vocoder.output_sampling_rate,\n)\n"
  },
  {
    "path": "examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-DistilledPipeline.py",
    "content": "import torch\nfrom diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig\nfrom diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2\n\nvram_config = {\n    \"offload_dtype\": torch.float8_e5m2,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.float8_e5m2,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e5m2,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = LTX2AudioVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\", origin_file_pattern=\"model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Lightricks/LTX-2.3\", origin_file_pattern=\"ltx-2.3-22b-distilled.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Lightricks/LTX-2.3\", origin_file_pattern=\"ltx-2.3-spatial-upscaler-x2-1.0.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\n\nprompt = \"A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”\"\nnegative_prompt = (\n    \"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, \"\n    \"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, \"\n    \"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, \"\n    \"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of \"\n    \"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent \"\n    \"lighting direction, color banding, cartoonish rendering, 3D CGI 
look, unrealistic materials, uncanny \"\n    \"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, \"\n    \"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, \"\n    \"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward \"\n    \"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, \"\n    \"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts.\"\n)\nheight, width, num_frames = 512 * 2, 768 * 2, 121\nvideo, audio = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    seed=43,\n    height=height,\n    width=width,\n    num_frames=num_frames,\n    tiled=True,\n    use_distilled_pipeline=True,\n)\nwrite_video_audio_ltx2(\n    video=video,\n    audio=audio,\n    output_path='ltx2.3_distilled.mp4',\n    fps=24,\n    audio_sample_rate=pipe.audio_vocoder.output_sampling_rate,\n)\n"
  },
  {
    "path": "examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-IC-LoRA-Motion-Track-Control.py",
    "content": "import torch\nfrom diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig\nfrom diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2\nfrom modelscope import dataset_snapshot_download\nfrom diffsynth.utils.data import VideoData\n\nvram_config = {\n    \"offload_dtype\": torch.float8_e5m2,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.float8_e5m2,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e5m2,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = LTX2AudioVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\", origin_file_pattern=\"model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Lightricks/LTX-2.3\", origin_file_pattern=\"ltx-2.3-22b-dev.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Lightricks/LTX-2.3\", origin_file_pattern=\"ltx-2.3-spatial-upscaler-x2-1.0.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\"),\n    stage2_lora_config=ModelConfig(model_id=\"Lightricks/LTX-2.3\", origin_file_pattern=\"ltx-2.3-22b-distilled-lora-384.safetensors\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\npipe.load_lora(pipe.dit, ModelConfig(model_id=\"Lightricks/LTX-2.3-22b-IC-LoRA-Motion-Track-Control\", origin_file_pattern=\"ltx-2.3-22b-ic-lora-motion-track-control-ref0.5.safetensors\"))\ndataset_snapshot_download(\"DiffSynth-Studio/example_video_dataset\", allow_file_pattern=\"ltx2/*\", local_dir=\"data/example_video_dataset\")\nprompt = \"[VISUAL]:Two cute orange cats, wearing boxing gloves, stand on a boxing ring and fight each other. 
[SOUNDS]:the sound of two cats boxing\"\nnegative_prompt = (\n    \"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, \"\n    \"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, \"\n    \"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, \"\n    \"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of \"\n    \"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent \"\n    \"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny \"\n    \"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, \"\n    \"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, \"\n    \"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward \"\n    \"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, \"\n    \"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts.\"\n)\nheight, width, num_frames = 512 * 2, 768 * 2, 121\nref_scale_factor = 2\nframe_rate = 24\ninput_image = VideoData(\"data/example_video_dataset/ltx2/video1.mp4\", height=height // ref_scale_factor // 2, width=width // ref_scale_factor // 2)[0]\ninput_video = VideoData(\"data/example_video_dataset/ltx2/spatial_tracker_v2.mp4\", height=height // ref_scale_factor // 2, width=width // ref_scale_factor // 2).raw_data()\nvideo, audio = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    seed=43,\n    height=height,\n    width=width,\n    num_frames=num_frames,\n    frame_rate=frame_rate,\n    in_context_videos=[input_video],\n    in_context_downsample_factor=ref_scale_factor,\n    input_images=[input_image],\n    
input_images_indexes=[0],\n    tiled=True,\n    use_two_stage_pipeline=True,\n    clear_lora_before_state_two=True,\n)\nwrite_video_audio_ltx2(\n    video=video,\n    audio=audio,\n    output_path='ltx2.3_ic_lora.mp4',\n    fps=frame_rate,\n    audio_sample_rate=pipe.audio_vocoder.output_sampling_rate,\n)\n"
  },
  {
    "path": "examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-IC-LoRA-Union-Control.py",
    "content": "import torch\nfrom diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig\nfrom diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2\nfrom modelscope import dataset_snapshot_download\nfrom diffsynth.utils.data import VideoData\n\nvram_config = {\n    \"offload_dtype\": torch.float8_e5m2,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.float8_e5m2,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e5m2,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = LTX2AudioVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\", origin_file_pattern=\"model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Lightricks/LTX-2.3\", origin_file_pattern=\"ltx-2.3-22b-dev.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Lightricks/LTX-2.3\", origin_file_pattern=\"ltx-2.3-spatial-upscaler-x2-1.0.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\"),\n    stage2_lora_config=ModelConfig(model_id=\"Lightricks/LTX-2.3\", origin_file_pattern=\"ltx-2.3-22b-distilled-lora-384.safetensors\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\npipe.load_lora(pipe.dit, ModelConfig(model_id=\"Lightricks/LTX-2.3-22b-IC-LoRA-Union-Control\", origin_file_pattern=\"ltx-2.3-22b-ic-lora-union-control-ref0.5.safetensors\"))\ndataset_snapshot_download(\"DiffSynth-Studio/example_video_dataset\", allow_file_pattern=\"ltx2/*\", local_dir=\"data/example_video_dataset\")\nprompt = \"[VISUAL]:Two cute orange cats, wearing boxing gloves, stand on a boxing ring and fight each other. 
[SOUNDS]:the sound of two cats boxing\"\nnegative_prompt = (\n    \"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, \"\n    \"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, \"\n    \"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, \"\n    \"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of \"\n    \"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent \"\n    \"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny \"\n    \"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, \"\n    \"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, \"\n    \"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward \"\n    \"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, \"\n    \"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts.\"\n)\nheight, width, num_frames = 512 * 2, 768 * 2, 121\nref_scale_factor = 2\nframe_rate = 24\ninput_video = VideoData(\"data/example_video_dataset/ltx2/depth_video.mp4\", height=height // ref_scale_factor // 2, width=width // ref_scale_factor // 2).raw_data()\nvideo, audio = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    seed=43,\n    height=height,\n    width=width,\n    num_frames=num_frames,\n    frame_rate=frame_rate,\n    in_context_videos=[input_video],\n    in_context_downsample_factor=ref_scale_factor,\n    tiled=True,\n    use_two_stage_pipeline=True,\n    clear_lora_before_state_two=True,\n)\nwrite_video_audio_ltx2(\n    video=video,\n    audio=audio,\n    output_path='ltx2.3_ic_lora.mp4',\n    
fps=frame_rate,\n    audio_sample_rate=pipe.audio_vocoder.output_sampling_rate,\n)\n"
  },
  {
    "path": "examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-OneStage.py",
    "content": "import torch\nfrom diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig\nfrom diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2\n\nvram_config = {\n    \"offload_dtype\": torch.float8_e5m2,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.float8_e5m2,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e5m2,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = LTX2AudioVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\", origin_file_pattern=\"model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Lightricks/LTX-2.3\", origin_file_pattern=\"ltx-2.3-22b-dev.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\nprompt = \"A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”\"\nnegative_prompt = \"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic 
voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts.\"\nheight, width, num_frames = 512, 768, 121\nvideo, audio = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    seed=43,\n    height=height,\n    width=width,\n    num_frames=num_frames,\n    tiled=True,\n)\nwrite_video_audio_ltx2(\n    video=video,\n    audio=audio,\n    output_path='ltx2.3_onestage.mp4',\n    fps=24,\n    audio_sample_rate=pipe.audio_vocoder.output_sampling_rate,\n)\n"
  },
  {
    "path": "examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-TwoStage-Retake.py",
    "content": "import torch\nfrom diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig\nfrom diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2\nfrom diffsynth.utils.data.audio import read_audio\nfrom modelscope import dataset_snapshot_download\nfrom diffsynth.utils.data import VideoData\n\nvram_config = {\n    \"offload_dtype\": torch.float8_e5m2,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.float8_e5m2,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e5m2,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = LTX2AudioVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\", origin_file_pattern=\"model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Lightricks/LTX-2.3\", origin_file_pattern=\"ltx-2.3-22b-dev.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Lightricks/LTX-2.3\", origin_file_pattern=\"ltx-2.3-spatial-upscaler-x2-1.0.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\"),\n    stage2_lora_config=ModelConfig(model_id=\"Lightricks/LTX-2.3\", origin_file_pattern=\"ltx-2.3-22b-distilled-lora-384.safetensors\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\n\ndataset_snapshot_download(\"DiffSynth-Studio/example_video_dataset\", allow_file_pattern=\"ltx2/*\", local_dir=\"data/example_video_dataset\")\nprompt = \"A beautiful woman with a flower crown is singing happily under a blooming cherry tree. She sings: 'Mummy don't know daddy's getting hot. 
At the body shop'\"\nnegative_prompt = (\n    \"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, \"\n    \"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, \"\n    \"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, \"\n    \"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of \"\n    \"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent \"\n    \"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny \"\n    \"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, \"\n    \"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, \"\n    \"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward \"\n    \"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, \"\n    \"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts.\"\n)\n\nheight, width, num_frames, frame_rate = 512 * 2, 768 * 2, 121, 24\npath = \"data/example_video_dataset/ltx2/video2.mp4\"\nvideo = VideoData(path, height=height, width=width).raw_data()[:num_frames]\nassert len(video) == num_frames, f\"Input video has {len(video)} frames, but expected {num_frames} frames based on the specified num_frames argument.\"\nduration = num_frames / frame_rate\naudio, audio_sample_rate = read_audio(path)\n\n# Regenerate the video within time regions. 
You can specify different time regions for video frames and audio retake.\n# retake regions are in seconds, and the example below retakes video frames in the time regions of [1s, 2s] and [3s, 4s], and retakes audio in the time regions of [0s, 1s] and [4s, 5s].\nvideo, audio = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    retake_video=video,\n    retake_video_regions=[(1, 2), (3, 4)],\n    retake_audio=audio,\n    audio_sample_rate=audio_sample_rate,\n    retake_audio_regions=[(0, 1), (4, 5)],\n    seed=43,\n    height=height,\n    width=width,\n    num_frames=num_frames,\n    frame_rate=frame_rate,\n    tiled=True,\n    use_two_stage_pipeline=True,\n)\nwrite_video_audio_ltx2(\n    video=video,\n    audio=audio,\n    output_path='ltx2.3_twostage_retake.mp4',\n    fps=frame_rate,\n    audio_sample_rate=pipe.audio_vocoder.output_sampling_rate,\n)\n"
  },
  {
    "path": "examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-TwoStage.py",
    "content": "import torch\nfrom diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig\nfrom diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2\n\nvram_config = {\n    \"offload_dtype\": torch.float8_e5m2,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.float8_e5m2,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e5m2,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = LTX2AudioVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\", origin_file_pattern=\"model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Lightricks/LTX-2.3\", origin_file_pattern=\"ltx-2.3-22b-dev.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Lightricks/LTX-2.3\", origin_file_pattern=\"ltx-2.3-spatial-upscaler-x2-1.0.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\"),\n    stage2_lora_config=ModelConfig(model_id=\"Lightricks/LTX-2.3\", origin_file_pattern=\"ltx-2.3-22b-distilled-lora-384.safetensors\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\nprompt = \"A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”\"\nnegative_prompt = (\n    \"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, \"\n    \"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, \"\n    \"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, \"\n    \"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of \"\n    \"field, background too sharp, background 
clutter, distracting reflections, harsh shadows, inconsistent \"\n    \"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny \"\n    \"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, \"\n    \"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, \"\n    \"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward \"\n    \"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, \"\n    \"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts.\"\n)\nheight, width, num_frames = 512 * 2, 768 * 2, 121\nvideo, audio = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    seed=43,\n    height=height,\n    width=width,\n    num_frames=num_frames,\n    tiled=True,\n    use_two_stage_pipeline=True,\n)\nwrite_video_audio_ltx2(\n    video=video,\n    audio=audio,\n    output_path='ltx2.3_twostage.mp4',\n    fps=24,\n    audio_sample_rate=pipe.audio_vocoder.output_sampling_rate,\n)\n"
  },
  {
    "path": "examples/ltx2/model_training/full/LTX-2-T2AV-splited.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"ltx2/LTX-2-T2AV-splited/*\" --local_dir ./data/diffsynth_example_dataset\n\n# Splited Training\naccelerate launch examples/ltx2/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/ltx2/LTX-2-T2AV-splited \\\n  --dataset_metadata_path data/diffsynth_example_dataset/ltx2/LTX-2-T2AV-splited/metadata.csv \\\n  --data_file_keys \"video,input_audio\" \\\n  --extra_inputs \"input_audio\" \\\n  --height 512 \\\n  --width 768 \\\n  --num_frames 121 \\\n  --dataset_repeat 1 \\\n  --model_id_with_origin_paths \"DiffSynth-Studio/LTX-2-Repackage:text_encoder_post_modules.safetensors,DiffSynth-Studio/LTX-2-Repackage:video_vae_encoder.safetensors,DiffSynth-Studio/LTX-2-Repackage:audio_vae_encoder.safetensors,google/gemma-3-12b-it-qat-q4_0-unquantized:model-*.safetensors\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/LTX2-T2AV-full-splited-cache\" \\\n  --trainable_models \"dit\" \\\n  --use_gradient_checkpointing \\\n  --task \"sft:data_process\"\n\naccelerate launch --config_file examples/qwen_image/model_training/full/accelerate_config_zero2offload.yaml examples/ltx2/model_training/train.py \\\n  --dataset_base_path ./models/train/LTX2-T2AV-full-splited-cache \\\n  --data_file_keys \"video,input_audio\" \\\n  --extra_inputs \"input_audio\" \\\n  --height 512 \\\n  --width 768 \\\n  --num_frames 121 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"DiffSynth-Studio/LTX-2-Repackage:transformer.safetensors\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/LTX2-T2AV-full\" \\\n  --trainable_models \"dit\" \\\n  --use_gradient_checkpointing \\\n  --task \"sft:train\"\n"
  },
  {
    "path": "examples/ltx2/model_training/full/LTX-2.3-I2AV-splited.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"ltx2/LTX-2.3-I2AV-splited/*\" --local_dir ./data/diffsynth_example_dataset\n\n# Splited Training\naccelerate launch examples/ltx2/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/ltx2/LTX-2.3-I2AV-splited \\\n  --dataset_metadata_path data/diffsynth_example_dataset/ltx2/LTX-2.3-I2AV-splited/metadata.csv \\\n  --data_file_keys \"video,input_audio\" \\\n  --extra_inputs \"input_audio,input_image\" \\\n  --height 512 \\\n  --width 768 \\\n  --num_frames 121 \\\n  --dataset_repeat 1 \\\n  --model_id_with_origin_paths \"DiffSynth-Studio/LTX-2.3-Repackage:text_encoder_post_modules.safetensors,DiffSynth-Studio/LTX-2.3-Repackage:video_vae_encoder.safetensors,DiffSynth-Studio/LTX-2.3-Repackage:audio_vae_encoder.safetensors,google/gemma-3-12b-it-qat-q4_0-unquantized:model-*.safetensors\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/LTX2.3-I2AV-full-splited-cache\" \\\n  --trainable_models \"dit\" \\\n  --use_gradient_checkpointing \\\n  --task \"sft:data_process\"\n\naccelerate launch --config_file examples/qwen_image/model_training/full/accelerate_config_zero2offload.yaml examples/ltx2/model_training/train.py \\\n  --dataset_base_path ./models/train/LTX2.3-I2AV-full-splited-cache \\\n  --data_file_keys \"video,input_audio\" \\\n  --extra_inputs \"input_audio,input_image\" \\\n  --height 512 \\\n  --width 768 \\\n  --num_frames 121 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"DiffSynth-Studio/LTX-2.3-Repackage:transformer.safetensors\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/LTX2.3-I2AV-full\" \\\n  --trainable_models \"dit\" \\\n  --use_gradient_checkpointing \\\n  --task \"sft:train\"\n"
  },
  {
    "path": "examples/ltx2/model_training/full/LTX-2.3-T2AV-splited.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"ltx2/LTX-2.3-T2AV-splited/*\" --local_dir ./data/diffsynth_example_dataset\n\n# Splited Training\naccelerate launch examples/ltx2/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/ltx2/LTX-2.3-T2AV-splited \\\n  --dataset_metadata_path data/diffsynth_example_dataset/ltx2/LTX-2.3-T2AV-splited/metadata.csv \\\n  --data_file_keys \"video,input_audio\" \\\n  --extra_inputs \"input_audio\" \\\n  --height 512 \\\n  --width 768 \\\n  --num_frames 121 \\\n  --dataset_repeat 1 \\\n  --model_id_with_origin_paths \"DiffSynth-Studio/LTX-2.3-Repackage:text_encoder_post_modules.safetensors,DiffSynth-Studio/LTX-2.3-Repackage:video_vae_encoder.safetensors,DiffSynth-Studio/LTX-2.3-Repackage:audio_vae_encoder.safetensors,google/gemma-3-12b-it-qat-q4_0-unquantized:model-*.safetensors\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/LTX2.3-T2AV-full-splited-cache\" \\\n  --trainable_models \"dit\" \\\n  --use_gradient_checkpointing \\\n  --task \"sft:data_process\"\n\naccelerate launch --config_file examples/qwen_image/model_training/full/accelerate_config_zero2offload.yaml examples/ltx2/model_training/train.py \\\n  --dataset_base_path ./models/train/LTX2.3-T2AV-full-splited-cache \\\n  --data_file_keys \"video,input_audio\" \\\n  --extra_inputs \"input_audio\" \\\n  --height 512 \\\n  --width 768 \\\n  --num_frames 121 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"DiffSynth-Studio/LTX-2.3-Repackage:transformer.safetensors\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/LTX2.3-T2AV-full\" \\\n  --trainable_models \"dit\" \\\n  --use_gradient_checkpointing \\\n  --task \"sft:train\"\n"
  },
  {
    "path": "examples/ltx2/model_training/lora/LTX-2-T2AV-IC-LoRA-splited.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"ltx2/LTX-2-T2AV-IC-LoRA-splited/*\" --local_dir ./data/diffsynth_example_dataset\n\n# Splited Training\naccelerate launch examples/ltx2/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/ltx2/LTX-2-T2AV-IC-LoRA-splited \\\n  --dataset_metadata_path data/diffsynth_example_dataset/ltx2/LTX-2-T2AV-IC-LoRA-splited/metadata.json \\\n  --data_file_keys \"video,input_audio,in_context_videos\" \\\n  --extra_inputs \"input_audio,in_context_videos,in_context_downsample_factor,frame_rate\" \\\n  --height 512 \\\n  --width 768 \\\n  --num_frames 81 \\\n  --dataset_repeat 1 \\\n  --model_id_with_origin_paths \"DiffSynth-Studio/LTX-2-Repackage:text_encoder_post_modules.safetensors,DiffSynth-Studio/LTX-2-Repackage:video_vae_encoder.safetensors,DiffSynth-Studio/LTX-2-Repackage:audio_vae_encoder.safetensors,google/gemma-3-12b-it-qat-q4_0-unquantized:model-*.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/LTX2-T2AV-IC-LoRA-splited-cache\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"to_k,to_q,to_v,to_out.0\" \\\n  --lora_rank 32 \\\n  --use_gradient_checkpointing \\\n  --task \"sft:data_process\"\n\naccelerate launch examples/ltx2/model_training/train.py \\\n  --dataset_base_path ./models/train/LTX2-T2AV-IC-LoRA-splited-cache \\\n  --data_file_keys \"video,input_audio,in_context_videos\" \\\n  --extra_inputs \"input_audio,in_context_videos,in_context_downsample_factor,frame_rate\" \\\n  --height 512 \\\n  --width 768 \\\n  --num_frames 81 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"DiffSynth-Studio/LTX-2-Repackage:transformer.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/LTX2-T2AV-IC-LoRA\" \\\n  --lora_base_model \"dit\" \\\n  
--lora_target_modules \"to_k,to_q,to_v,to_out.0\" \\\n  --lora_rank 32 \\\n  --use_gradient_checkpointing \\\n  --task \"sft:train\"\n"
  },
  {
    "path": "examples/ltx2/model_training/lora/LTX-2-T2AV-noaudio.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"ltx2/LTX-2-T2AV-noaudio/*\" --local_dir ./data/diffsynth_example_dataset\n\n# single stage training\n# accelerate launch examples/ltx2/model_training/train.py \\\n#   --dataset_base_path data/diffsynth_example_dataset/ltx2/LTX-2-T2AV-noaudio \\\n#   --dataset_metadata_path data/diffsynth_example_dataset/ltx2/LTX-2-T2AV-noaudio/metadata.csv \\\n#   --height 256 \\\n#   --width 384 \\\n#   --num_frames 25\\\n#   --dataset_repeat 100 \\\n#   --model_id_with_origin_paths \"DiffSynth-Studio/LTX-2-Repackage:transformer.safetensors,DiffSynth-Studio/LTX-2-Repackage:text_encoder_post_modules.safetensors,DiffSynth-Studio/LTX-2-Repackage:video_vae_encoder.safetensors,DiffSynth-Studio/LTX-2-Repackage:audio_vae_encoder.safetensors,google/gemma-3-12b-it-qat-q4_0-unquantized:model-*.safetensors\" \\\n#   --learning_rate 1e-4 \\\n#   --num_epochs 5 \\\n#   --remove_prefix_in_ckpt \"pipe.dit.\" \\\n#   --output_path \"./models/train/LTX2-T2AV-noaudio_lora\" \\\n#   --lora_base_model \"dit\" \\\n#   --lora_target_modules \"to_k,to_q,to_v,to_out.0\" \\\n#   --lora_rank 32 \\\n#   --use_gradient_checkpointing \\\n#   --find_unused_parameters\n\n\n# Splited Training\naccelerate launch examples/ltx2/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/ltx2/LTX-2-T2AV-noaudio \\\n  --dataset_metadata_path data/diffsynth_example_dataset/ltx2/LTX-2-T2AV-noaudio/metadata.csv \\\n  --height 512 \\\n  --width 768 \\\n  --num_frames 121\\\n  --dataset_repeat 1 \\\n  --model_id_with_origin_paths \"DiffSynth-Studio/LTX-2-Repackage:text_encoder_post_modules.safetensors,DiffSynth-Studio/LTX-2-Repackage:video_vae_encoder.safetensors,DiffSynth-Studio/LTX-2-Repackage:audio_vae_encoder.safetensors,google/gemma-3-12b-it-qat-q4_0-unquantized:model-*.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path 
\"./models/train/LTX2-T2AV-noaudio_lora-splited-cache\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"to_k,to_q,to_v,to_out.0\" \\\n  --lora_rank 32 \\\n  --use_gradient_checkpointing \\\n  --task \"sft:data_process\"\n\n\naccelerate launch examples/ltx2/model_training/train.py \\\n  --dataset_base_path ./models/train/LTX2-T2AV-noaudio_lora-splited-cache \\\n  --height 512 \\\n  --width 768 \\\n  --num_frames 121\\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"DiffSynth-Studio/LTX-2-Repackage:transformer.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/LTX2-T2AV-noaudio_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"to_k,to_q,to_v,to_out.0\" \\\n  --lora_rank 32 \\\n  --use_gradient_checkpointing \\\n  --task \"sft:train\"\n"
  },
  {
    "path": "examples/ltx2/model_training/lora/LTX-2-T2AV-splited.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"ltx2/LTX-2-T2AV-splited/*\" --local_dir ./data/diffsynth_example_dataset\n\n# Single Stage Training not recommended for T2AV due to the large memory consumption. Please use the Splited Training instead.\n# accelerate launch examples/ltx2/model_training/train.py \\\n#   --dataset_base_path data/diffsynth_example_dataset/ltx2/LTX-2-T2AV-splited \\\n#   --dataset_metadata_path data/diffsynth_example_dataset/ltx2/LTX-2-T2AV-splited/metadata.csv \\\n#   --data_file_keys \"video,input_audio\" \\\n#   --extra_inputs \"input_audio\" \\\n#   --height 256 \\\n#   --width 384 \\\n#   --num_frames 25\\\n#   --dataset_repeat 100 \\\n#   --model_id_with_origin_paths \"DiffSynth-Studio/LTX-2-Repackage:transformer.safetensors,DiffSynth-Studio/LTX-2-Repackage:text_encoder_post_modules.safetensors,DiffSynth-Studio/LTX-2-Repackage:video_vae_encoder.safetensors,DiffSynth-Studio/LTX-2-Repackage:audio_vae_encoder.safetensors,google/gemma-3-12b-it-qat-q4_0-unquantized:model-*.safetensors\" \\\n#   --learning_rate 1e-4 \\\n#   --num_epochs 5 \\\n#   --remove_prefix_in_ckpt \"pipe.dit.\" \\\n#   --output_path \"./models/train/LTX2-T2AV_lora\" \\\n#   --lora_base_model \"dit\" \\\n#   --lora_target_modules \"to_k,to_q,to_v,to_out.0\" \\\n#   --lora_rank 32 \\\n#   --use_gradient_checkpointing \\\n#   --find_unused_parameters\n\n# Splited Training\naccelerate launch examples/ltx2/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/ltx2/LTX-2-T2AV-splited \\\n  --dataset_metadata_path data/diffsynth_example_dataset/ltx2/LTX-2-T2AV-splited/metadata.csv \\\n  --data_file_keys \"video,input_audio\" \\\n  --extra_inputs \"input_audio\" \\\n  --height 512 \\\n  --width 768 \\\n  --num_frames 121 \\\n  --dataset_repeat 1 \\\n  --model_id_with_origin_paths 
\"DiffSynth-Studio/LTX-2-Repackage:text_encoder_post_modules.safetensors,DiffSynth-Studio/LTX-2-Repackage:video_vae_encoder.safetensors,DiffSynth-Studio/LTX-2-Repackage:audio_vae_encoder.safetensors,google/gemma-3-12b-it-qat-q4_0-unquantized:model-*.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/LTX2-T2AV_lora-splited-cache\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"to_k,to_q,to_v,to_out.0\" \\\n  --lora_rank 32 \\\n  --use_gradient_checkpointing \\\n  --task \"sft:data_process\"\n\naccelerate launch examples/ltx2/model_training/train.py \\\n  --dataset_base_path ./models/train/LTX2-T2AV_lora-splited-cache \\\n  --data_file_keys \"video,input_audio\" \\\n  --extra_inputs \"input_audio\" \\\n  --height 512 \\\n  --width 768 \\\n  --num_frames 121 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"DiffSynth-Studio/LTX-2-Repackage:transformer.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/LTX2-T2AV_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"to_k,to_q,to_v,to_out.0\" \\\n  --lora_rank 32 \\\n  --use_gradient_checkpointing \\\n  --task \"sft:train\"\n"
  },
  {
    "path": "examples/ltx2/model_training/lora/LTX-2.3-I2AV-splited.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"ltx2/LTX-2.3-I2AV-splited/*\" --local_dir ./data/diffsynth_example_dataset\n\n# Splited Training\naccelerate launch examples/ltx2/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/ltx2/LTX-2.3-I2AV-splited \\\n  --dataset_metadata_path data/diffsynth_example_dataset/ltx2/LTX-2.3-I2AV-splited/metadata.csv \\\n  --data_file_keys \"video,input_audio\" \\\n  --extra_inputs \"input_audio,input_image\" \\\n  --height 512 \\\n  --width 768 \\\n  --num_frames 121 \\\n  --dataset_repeat 1 \\\n  --model_id_with_origin_paths \"DiffSynth-Studio/LTX-2.3-Repackage:text_encoder_post_modules.safetensors,DiffSynth-Studio/LTX-2.3-Repackage:video_vae_encoder.safetensors,DiffSynth-Studio/LTX-2.3-Repackage:audio_vae_encoder.safetensors,google/gemma-3-12b-it-qat-q4_0-unquantized:model-*.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/LTX2.3-I2AV_lora-splited-cache\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"to_k,to_q,to_v,to_out.0\" \\\n  --lora_rank 32 \\\n  --use_gradient_checkpointing \\\n  --task \"sft:data_process\"\n\naccelerate launch examples/ltx2/model_training/train.py \\\n  --dataset_base_path ./models/train/LTX2.3-I2AV_lora-splited-cache \\\n  --data_file_keys \"video,input_audio\" \\\n  --extra_inputs \"input_audio,input_image\" \\\n  --height 512 \\\n  --width 768 \\\n  --num_frames 121 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"DiffSynth-Studio/LTX-2.3-Repackage:transformer.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/LTX2.3-I2AV_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"to_k,to_q,to_v,to_out.0\" \\\n  --lora_rank 32 \\\n  --use_gradient_checkpointing \\\n  --task \"sft:train\"\n"
  },
  {
    "path": "examples/ltx2/model_training/lora/LTX-2.3-T2AV-IC-LoRA-splited.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"ltx2/LTX-2.3-T2AV-IC-LoRA-splited/*\" --local_dir ./data/diffsynth_example_dataset\n\n# Splited Training\naccelerate launch examples/ltx2/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/ltx2/LTX-2.3-T2AV-IC-LoRA-splited \\\n  --dataset_metadata_path data/diffsynth_example_dataset/ltx2/LTX-2.3-T2AV-IC-LoRA-splited/metadata.json \\\n  --data_file_keys \"video,input_audio,in_context_videos\" \\\n  --extra_inputs \"input_audio,in_context_videos,in_context_downsample_factor,frame_rate\" \\\n  --height 512 \\\n  --width 768 \\\n  --num_frames 81 \\\n  --dataset_repeat 1 \\\n  --model_id_with_origin_paths \"DiffSynth-Studio/LTX-2.3-Repackage:text_encoder_post_modules.safetensors,DiffSynth-Studio/LTX-2.3-Repackage:video_vae_encoder.safetensors,DiffSynth-Studio/LTX-2.3-Repackage:audio_vae_encoder.safetensors,google/gemma-3-12b-it-qat-q4_0-unquantized:model-*.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/LTX2.3-T2AV-IC-LoRA-splited-cache\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"to_k,to_q,to_v,to_out.0\" \\\n  --lora_rank 32 \\\n  --use_gradient_checkpointing \\\n  --task \"sft:data_process\"\n\naccelerate launch examples/ltx2/model_training/train.py \\\n  --dataset_base_path ./models/train/LTX2.3-T2AV-IC-LoRA-splited-cache \\\n  --data_file_keys \"video,input_audio,in_context_videos\" \\\n  --extra_inputs \"input_audio,in_context_videos,in_context_downsample_factor,frame_rate\" \\\n  --height 512 \\\n  --width 768 \\\n  --num_frames 81 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"DiffSynth-Studio/LTX-2.3-Repackage:transformer.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/LTX2.3-T2AV-IC-LoRA\" \\\n  
--lora_base_model \"dit\" \\\n  --lora_target_modules \"to_k,to_q,to_v,to_out.0\" \\\n  --lora_rank 32 \\\n  --use_gradient_checkpointing \\\n  --task \"sft:train\"\n"
  },
  {
    "path": "examples/ltx2/model_training/lora/LTX-2.3-T2AV-splited.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"ltx2/LTX-2.3-T2AV-splited/*\" --local_dir ./data/diffsynth_example_dataset\n\n# Splited Training\naccelerate launch examples/ltx2/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/ltx2/LTX-2.3-T2AV-splited \\\n  --dataset_metadata_path data/diffsynth_example_dataset/ltx2/LTX-2.3-T2AV-splited/metadata.csv \\\n  --data_file_keys \"video,input_audio\" \\\n  --extra_inputs \"input_audio\" \\\n  --height 512 \\\n  --width 768 \\\n  --num_frames 121 \\\n  --dataset_repeat 1 \\\n  --model_id_with_origin_paths \"DiffSynth-Studio/LTX-2.3-Repackage:text_encoder_post_modules.safetensors,DiffSynth-Studio/LTX-2.3-Repackage:video_vae_encoder.safetensors,DiffSynth-Studio/LTX-2.3-Repackage:audio_vae_encoder.safetensors,google/gemma-3-12b-it-qat-q4_0-unquantized:model-*.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/LTX2.3-T2AV_lora-splited-cache\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"to_k,to_q,to_v,to_out.0\" \\\n  --lora_rank 32 \\\n  --use_gradient_checkpointing \\\n  --task \"sft:data_process\"\n\naccelerate launch examples/ltx2/model_training/train.py \\\n  --dataset_base_path ./models/train/LTX2.3-T2AV_lora-splited-cache \\\n  --data_file_keys \"video,input_audio\" \\\n  --extra_inputs \"input_audio\" \\\n  --height 512 \\\n  --width 768 \\\n  --num_frames 121 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"DiffSynth-Studio/LTX-2.3-Repackage:transformer.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/LTX2.3-T2AV_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"to_k,to_q,to_v,to_out.0\" \\\n  --lora_rank 32 \\\n  --use_gradient_checkpointing \\\n  --task \"sft:train\"\n"
  },
  {
    "path": "examples/ltx2/model_training/scripts/split_model_statedicts.py",
    "content": "from safetensors.torch import save_file\nfrom diffsynth import hash_state_dict_keys\nfrom diffsynth.core import load_state_dict\nfrom diffsynth.models.model_loader import ModelPool\n\nmodel_pool = ModelPool()\nstate_dict = load_state_dict(\"models/Lightricks/LTX-2/ltx-2-19b-dev.safetensors\")\n\ndit_state_dict = {}\nfor name in state_dict:\n    if name.startswith(\"model.diffusion_model.\"):\n        new_name = name.replace(\"model.diffusion_model.\", \"\")\n        if new_name.startswith(\"audio_embeddings_connector.\") or new_name.startswith(\"video_embeddings_connector.\"):\n            continue\n        dit_state_dict[name] = state_dict[name]\n\nprint(f\"dit_state_dict keys hash: {hash_state_dict_keys(dit_state_dict)}\")\nsave_file(dit_state_dict, \"models/DiffSynth-Studio/LTX-2-Repackage/transformer.safetensors\")\nmodel_pool.auto_load_model(\n    \"models/DiffSynth-Studio/LTX-2-Repackage/transformer.safetensors\",\n)\n\n\nvideo_vae_encoder_state_dict = {}\nfor name in state_dict:\n    if name.startswith(\"vae.encoder.\"):\n        video_vae_encoder_state_dict[name] = state_dict[name]\n    elif name.startswith(\"vae.per_channel_statistics.\"):\n        video_vae_encoder_state_dict[name] = state_dict[name]\n\nsave_file(video_vae_encoder_state_dict, \"models/DiffSynth-Studio/LTX-2-Repackage/video_vae_encoder.safetensors\")\nprint(f\"video_vae_encoder keys hash: {hash_state_dict_keys(video_vae_encoder_state_dict)}\")\nmodel_pool.auto_load_model(\"models/DiffSynth-Studio/LTX-2-Repackage/video_vae_encoder.safetensors\")\n\n\nvideo_vae_decoder_state_dict = {}\nfor name in state_dict:\n    if name.startswith(\"vae.decoder.\"):\n        video_vae_decoder_state_dict[name] = state_dict[name]\n    elif name.startswith(\"vae.per_channel_statistics.\"):\n        video_vae_decoder_state_dict[name] = state_dict[name]\nsave_file(video_vae_decoder_state_dict, \"models/DiffSynth-Studio/LTX-2-Repackage/video_vae_decoder.safetensors\")\nprint(f\"video_vae_decoder 
keys hash: {hash_state_dict_keys(video_vae_decoder_state_dict)}\")\nmodel_pool.auto_load_model(\"models/DiffSynth-Studio/LTX-2-Repackage/video_vae_decoder.safetensors\")\n\n\naudio_vae_decoder_state_dict = {}\nfor name in state_dict:\n    if name.startswith(\"audio_vae.decoder.\"):\n        audio_vae_decoder_state_dict[name] = state_dict[name]\n    elif name.startswith(\"audio_vae.per_channel_statistics.\"):\n        audio_vae_decoder_state_dict[name] = state_dict[name]\nsave_file(audio_vae_decoder_state_dict, \"models/DiffSynth-Studio/LTX-2-Repackage/audio_vae_decoder.safetensors\")\nprint(f\"audio_vae_decoder keys hash: {hash_state_dict_keys(audio_vae_decoder_state_dict)}\")\nmodel_pool.auto_load_model(\"models/DiffSynth-Studio/LTX-2-Repackage/audio_vae_decoder.safetensors\")\n\n\naudio_vae_encoder_state_dict = {}\nfor name in state_dict:\n    if name.startswith(\"audio_vae.encoder.\"):\n        audio_vae_encoder_state_dict[name] = state_dict[name]\n    elif name.startswith(\"audio_vae.per_channel_statistics.\"):\n        audio_vae_encoder_state_dict[name] = state_dict[name]\nsave_file(audio_vae_encoder_state_dict, \"models/DiffSynth-Studio/LTX-2-Repackage/audio_vae_encoder.safetensors\")\nprint(f\"audio_vae_encoder keys hash: {hash_state_dict_keys(audio_vae_encoder_state_dict)}\")\nmodel_pool.auto_load_model(\"models/DiffSynth-Studio/LTX-2-Repackage/audio_vae_encoder.safetensors\")\n\n\naudio_vocoder_state_dict = {}\nfor name in state_dict:\n    if name.startswith(\"vocoder.\"):\n        audio_vocoder_state_dict[name] = state_dict[name]\nsave_file(audio_vocoder_state_dict, \"models/DiffSynth-Studio/LTX-2-Repackage/audio_vocoder.safetensors\")\nprint(f\"audio_vocoder keys hash: {hash_state_dict_keys(audio_vocoder_state_dict)}\")\nmodel_pool.auto_load_model(\"models/DiffSynth-Studio/LTX-2-Repackage/audio_vocoder.safetensors\")\n\n\ntext_encoder_post_modules_state_dict = {}\nfor name in state_dict:\n    if name.startswith(\"text_embedding_projection.\"):\n        
text_encoder_post_modules_state_dict[name] = state_dict[name]\n    elif name.startswith(\"model.diffusion_model.video_embeddings_connector.\"):\n        text_encoder_post_modules_state_dict[name] = state_dict[name]\n    elif name.startswith(\"model.diffusion_model.audio_embeddings_connector.\"):\n        text_encoder_post_modules_state_dict[name] = state_dict[name]\nsave_file(text_encoder_post_modules_state_dict, \"models/DiffSynth-Studio/LTX-2-Repackage/text_encoder_post_modules.safetensors\")\nprint(f\"text_encoder_post_modules keys hash: {hash_state_dict_keys(text_encoder_post_modules_state_dict)}\")\nmodel_pool.auto_load_model(\"models/DiffSynth-Studio/LTX-2-Repackage/text_encoder_post_modules.safetensors\")\n\n\nstate_dict = load_state_dict(\"models/Lightricks/LTX-2/ltx-2-19b-distilled.safetensors\")\ndit_state_dict = {}\nfor name in state_dict:\n    if name.startswith(\"model.diffusion_model.\"):\n        new_name = name.replace(\"model.diffusion_model.\", \"\")\n        if new_name.startswith(\"audio_embeddings_connector.\") or new_name.startswith(\"video_embeddings_connector.\"):\n            continue\n        dit_state_dict[name] = state_dict[name]\n\nprint(f\"dit_state_dict keys hash: {hash_state_dict_keys(dit_state_dict)}\")\nsave_file(dit_state_dict, \"models/DiffSynth-Studio/LTX-2-Repackage/transformer_distilled.safetensors\")\nmodel_pool.auto_load_model(\n    \"models/DiffSynth-Studio/LTX-2-Repackage/transformer_distilled.safetensors\",\n)"
  },
  {
    "path": "examples/ltx2/model_training/scripts/split_model_statedicts_ltx2.3.py",
    "content": "from safetensors.torch import save_file\nfrom diffsynth import hash_state_dict_keys\nfrom diffsynth.core import load_state_dict\nfrom diffsynth.models.model_loader import ModelPool\nimport os\n\nmodel_pool = ModelPool()\nstate_dict = load_state_dict(\"models/Lightricks/LTX-2.3/ltx-2.3-22b-dev.safetensors\")\nos.makedirs(\"models/DiffSynth-Studio/LTX-2.3-Repackage\", exist_ok=True)\n\ndit_state_dict = {}\nfor name in state_dict:\n    if name.startswith(\"model.diffusion_model.\"):\n        new_name = name.replace(\"model.diffusion_model.\", \"\")\n        if new_name.startswith(\"audio_embeddings_connector.\") or new_name.startswith(\"video_embeddings_connector.\"):\n            continue\n        dit_state_dict[name] = state_dict[name]\n\nprint(f\"dit_state_dict keys hash: {hash_state_dict_keys(dit_state_dict)}\")\nsave_file(dit_state_dict, \"models/DiffSynth-Studio/LTX-2.3-Repackage/transformer.safetensors\")\nmodel_pool.auto_load_model(\"models/DiffSynth-Studio/LTX-2.3-Repackage/transformer.safetensors\")\n\n\nvideo_vae_encoder_state_dict = {}\nfor name in state_dict:\n    if name.startswith(\"vae.encoder.\"):\n        video_vae_encoder_state_dict[name] = state_dict[name]\n    elif name.startswith(\"vae.per_channel_statistics.\"):\n        video_vae_encoder_state_dict[name] = state_dict[name]\n\nsave_file(video_vae_encoder_state_dict, \"models/DiffSynth-Studio/LTX-2.3-Repackage/video_vae_encoder.safetensors\")\nprint(f\"video_vae_encoder keys hash: {hash_state_dict_keys(video_vae_encoder_state_dict)}\")\nmodel_pool.auto_load_model(\"models/DiffSynth-Studio/LTX-2.3-Repackage/video_vae_encoder.safetensors\")\n\n\nvideo_vae_decoder_state_dict = {}\nfor name in state_dict:\n    if name.startswith(\"vae.decoder.\"):\n        video_vae_decoder_state_dict[name] = state_dict[name]\n    elif name.startswith(\"vae.per_channel_statistics.\"):\n        video_vae_decoder_state_dict[name] = state_dict[name]\nsave_file(video_vae_decoder_state_dict, 
\"models/DiffSynth-Studio/LTX-2.3-Repackage/video_vae_decoder.safetensors\")\nprint(f\"video_vae_decoder keys hash: {hash_state_dict_keys(video_vae_decoder_state_dict)}\")\nmodel_pool.auto_load_model(\"models/DiffSynth-Studio/LTX-2.3-Repackage/video_vae_decoder.safetensors\")\n\n\naudio_vae_decoder_state_dict = {}\nfor name in state_dict:\n    if name.startswith(\"audio_vae.decoder.\"):\n        audio_vae_decoder_state_dict[name] = state_dict[name]\n    elif name.startswith(\"audio_vae.per_channel_statistics.\"):\n        audio_vae_decoder_state_dict[name] = state_dict[name]\nsave_file(audio_vae_decoder_state_dict, \"models/DiffSynth-Studio/LTX-2.3-Repackage/audio_vae_decoder.safetensors\")\nprint(f\"audio_vae_decoder keys hash: {hash_state_dict_keys(audio_vae_decoder_state_dict)}\")\nmodel_pool.auto_load_model(\"models/DiffSynth-Studio/LTX-2.3-Repackage/audio_vae_decoder.safetensors\")\n\n\naudio_vae_encoder_state_dict = {}\nfor name in state_dict:\n    if name.startswith(\"audio_vae.encoder.\"):\n        audio_vae_encoder_state_dict[name] = state_dict[name]\n    elif name.startswith(\"audio_vae.per_channel_statistics.\"):\n        audio_vae_encoder_state_dict[name] = state_dict[name]\nsave_file(audio_vae_encoder_state_dict, \"models/DiffSynth-Studio/LTX-2.3-Repackage/audio_vae_encoder.safetensors\")\nprint(f\"audio_vae_encoder keys hash: {hash_state_dict_keys(audio_vae_encoder_state_dict)}\")\nmodel_pool.auto_load_model(\"models/DiffSynth-Studio/LTX-2.3-Repackage/audio_vae_encoder.safetensors\")\n\n\naudio_vocoder_state_dict = {}\nfor name in state_dict:\n    if name.startswith(\"vocoder.\"):\n        audio_vocoder_state_dict[name] = state_dict[name]\nsave_file(audio_vocoder_state_dict, \"models/DiffSynth-Studio/LTX-2.3-Repackage/audio_vocoder.safetensors\")\nprint(f\"audio_vocoder keys hash: 
{hash_state_dict_keys(audio_vocoder_state_dict)}\")\nmodel_pool.auto_load_model(\"models/DiffSynth-Studio/LTX-2.3-Repackage/audio_vocoder.safetensors\")\n\n\ntext_encoder_post_modules_state_dict = {}\nfor name in state_dict:\n    if name.startswith(\"text_embedding_projection.\"):\n        text_encoder_post_modules_state_dict[name] = state_dict[name]\n    elif name.startswith(\"model.diffusion_model.video_embeddings_connector.\"):\n        text_encoder_post_modules_state_dict[name] = state_dict[name]\n    elif name.startswith(\"model.diffusion_model.audio_embeddings_connector.\"):\n        text_encoder_post_modules_state_dict[name] = state_dict[name]\nsave_file(text_encoder_post_modules_state_dict, \"models/DiffSynth-Studio/LTX-2.3-Repackage/text_encoder_post_modules.safetensors\")\nprint(f\"text_encoder_post_modules keys hash: {hash_state_dict_keys(text_encoder_post_modules_state_dict)}\")\nmodel_pool.auto_load_model(\"models/DiffSynth-Studio/LTX-2.3-Repackage/text_encoder_post_modules.safetensors\")\n\n\nstate_dict = load_state_dict(\"models/Lightricks/LTX-2.3/ltx-2.3-22b-distilled.safetensors\")\ndit_state_dict = {}\nfor name in state_dict:\n    if name.startswith(\"model.diffusion_model.\"):\n        new_name = name.replace(\"model.diffusion_model.\", \"\")\n        if new_name.startswith(\"audio_embeddings_connector.\") or new_name.startswith(\"video_embeddings_connector.\"):\n            continue\n        dit_state_dict[name] = state_dict[name]\n\nprint(f\"dit_state_dict keys hash: {hash_state_dict_keys(dit_state_dict)}\")\nsave_file(dit_state_dict, \"models/DiffSynth-Studio/LTX-2.3-Repackage/transformer_distilled.safetensors\")\nmodel_pool.auto_load_model(\"models/DiffSynth-Studio/LTX-2.3-Repackage/transformer_distilled.safetensors\")\n"
  },
  {
    "path": "examples/ltx2/model_training/train.py",
    "content": "import torch, os, argparse, accelerate, warnings\nfrom diffsynth.core import UnifiedDataset\nfrom diffsynth.core.data.operators import LoadAudioWithTorchaudio, ToAbsolutePath, RouteByType, SequencialProcess\nfrom diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig\nfrom diffsynth.diffusion import *\nos.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n\n\nclass LTX2TrainingModule(DiffusionTrainingModule):\n    def __init__(\n        self,\n        model_paths=None, model_id_with_origin_paths=None,\n        tokenizer_path=None,\n        trainable_models=None,\n        lora_base_model=None, lora_target_modules=\"\", lora_rank=32, lora_checkpoint=None,\n        preset_lora_path=None, preset_lora_model=None,\n        use_gradient_checkpointing=True,\n        use_gradient_checkpointing_offload=False,\n        extra_inputs=None,\n        fp8_models=None,\n        offload_models=None,\n        device=\"cpu\",\n        task=\"sft\",\n    ):\n        super().__init__()\n        # Warning\n        if not use_gradient_checkpointing:\n            warnings.warn(\"Gradient checkpointing is detected as disabled. 
To prevent out-of-memory errors, the training framework will forcibly enable gradient checkpointing.\")\n            use_gradient_checkpointing = True\n        \n        # Load models\n        model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device)\n        tokenizer_config = ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\") if tokenizer_path is None else ModelConfig(tokenizer_path)\n        self.pipe = LTX2AudioVideoPipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config)\n        self.pipe = self.split_pipeline_units(\n            task, self.pipe, trainable_models, lora_base_model,\n            remove_unnecessary_params=True,\n            force_remove_params_shared=(\"audio_latents\", \"video_latents\"),\n            force_remove_params_nega=(\"audio_context\", \"video_context\")\n        )\n        # Training mode\n        self.switch_pipe_to_training_mode(\n            self.pipe, trainable_models,\n            lora_base_model, lora_target_modules, lora_rank, lora_checkpoint,\n            preset_lora_path, preset_lora_model,\n            task=task,\n        )\n        \n        # Store other configs\n        self.use_gradient_checkpointing = use_gradient_checkpointing\n        self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload\n        self.extra_inputs = extra_inputs.split(\",\") if extra_inputs is not None else []\n        self.fp8_models = fp8_models\n        self.task = task\n        self.task_to_loss = {\n            \"sft:data_process\": lambda pipe, *args: args,\n            \"sft\": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTAudioVideoLoss(pipe, **inputs_shared, **inputs_posi),\n            \"sft:train\": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTAudioVideoLoss(pipe, **inputs_shared, 
**inputs_posi),\n        }\n        \n    def parse_extra_inputs(self, data, extra_inputs, inputs_shared):\n        for extra_input in extra_inputs:\n            if extra_input == \"input_image\":\n                inputs_shared[\"input_images\"] = [data[\"video\"][0]]\n                inputs_shared[\"input_images_indexes\"] = [0]\n                inputs_shared[\"input_images_strength\"] = 1.0\n            else:\n                inputs_shared[extra_input] = data[extra_input]\n        return inputs_shared\n    \n    def get_pipeline_inputs(self, data):\n        inputs_posi = {\"prompt\": data[\"prompt\"]}\n        inputs_nega = {}\n        inputs_shared = {\n            # Assume you are using this pipeline for inference,\n            # please fill in the input parameters.\n            \"input_video\": data[\"video\"],\n            \"height\": data[\"video\"][0].size[1],\n            \"width\": data[\"video\"][0].size[0],\n            \"num_frames\": len(data[\"video\"]),\n            \"frame_rate\": data.get(\"frame_rate\", 24),\n            # Please do not modify the following parameters\n            # unless you clearly know what this will cause.\n            \"cfg_scale\": 1,\n            \"tiled\": False,\n            \"rand_device\": self.pipe.device,\n            \"use_gradient_checkpointing\": self.use_gradient_checkpointing,\n            \"use_gradient_checkpointing_offload\": self.use_gradient_checkpointing_offload,\n            \"video_patchifier\": self.pipe.video_patchifier,\n            \"audio_patchifier\": self.pipe.audio_patchifier,\n        }\n        inputs_shared = self.parse_extra_inputs(data, self.extra_inputs, inputs_shared)\n        return inputs_shared, inputs_posi, inputs_nega\n    \n    def forward(self, data, inputs=None):\n        if inputs is None: inputs = self.get_pipeline_inputs(data)\n        inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype)\n        for unit in self.pipe.units:\n            inputs 
= self.pipe.unit_runner(unit, self.pipe, *inputs)\n        loss = self.task_to_loss[self.task](self.pipe, *inputs)\n        return loss\n\n\ndef ltx2_parser():\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser = add_general_config(parser)\n    parser = add_video_size_config(parser)\n    parser.add_argument(\"--tokenizer_path\", type=str, default=None, help=\"Path to tokenizer.\")\n    parser.add_argument(\"--frame_rate\", type=float, default=24, help=\"frame rate of the training videos.\")\n    parser.add_argument(\"--initialize_model_on_cpu\", default=False, action=\"store_true\", help=\"Whether to initialize models on CPU.\")\n    return parser\n\n\nif __name__ == \"__main__\":\n    parser = ltx2_parser()\n    args = parser.parse_args()\n    accelerator = accelerate.Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        kwargs_handlers=[accelerate.DistributedDataParallelKwargs(find_unused_parameters=args.find_unused_parameters)],\n    )\n    video_processor = UnifiedDataset.default_video_operator(\n            base_path=args.dataset_base_path,\n            max_pixels=args.max_pixels,\n            height=args.height,\n            width=args.width,\n            height_division_factor=32,\n            width_division_factor=32,\n            num_frames=args.num_frames,\n            time_division_factor=8,\n            time_division_remainder=1,\n            frame_rate=args.frame_rate,\n            fix_frame_rate=True,\n        )\n    dataset = UnifiedDataset(\n        base_path=args.dataset_base_path,\n        metadata_path=args.dataset_metadata_path,\n        repeat=args.dataset_repeat,\n        data_file_keys=args.data_file_keys.split(\",\"),\n        main_data_operator=video_processor,\n        special_operator_map={\n            \"input_audio\": ToAbsolutePath(args.dataset_base_path) >> LoadAudioWithTorchaudio(num_frames=args.num_frames, time_division_factor=8, 
time_division_remainder=1, frame_rate=args.frame_rate),\n            \"in_context_videos\": RouteByType(operator_map=[\n                (str, video_processor),\n                (list, SequencialProcess(video_processor)),\n            ]),\n        }\n    )\n    model = LTX2TrainingModule(\n        model_paths=args.model_paths,\n        model_id_with_origin_paths=args.model_id_with_origin_paths,\n        tokenizer_path=args.tokenizer_path,\n        trainable_models=args.trainable_models,\n        lora_base_model=args.lora_base_model,\n        lora_target_modules=args.lora_target_modules,\n        lora_rank=args.lora_rank,\n        lora_checkpoint=args.lora_checkpoint,\n        preset_lora_path=args.preset_lora_path,\n        preset_lora_model=args.preset_lora_model,\n        use_gradient_checkpointing=args.use_gradient_checkpointing,\n        use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,\n        extra_inputs=args.extra_inputs,\n        fp8_models=args.fp8_models,\n        offload_models=args.offload_models,\n        task=args.task,\n        device=\"cpu\" if args.initialize_model_on_cpu else accelerator.device,\n    )\n    model_logger = ModelLogger(\n        args.output_path,\n        remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,\n    )\n    launcher_map = {\n        \"sft:data_process\": launch_data_process_task,\n        \"direct_distill:data_process\": launch_data_process_task,\n        \"sft\": launch_training_task,\n        \"sft:train\": launch_training_task,\n        \"direct_distill\": launch_training_task,\n        \"direct_distill:train\": launch_training_task,\n    }\n    launcher_map[args.task](accelerator, dataset, model, model_logger, args=args)\n"
  },
  {
    "path": "examples/ltx2/model_training/validate_full/LTX-2-T2AV.py",
    "content": "import torch\nfrom diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig\nfrom diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2\n\nvram_config = {\n    \"offload_dtype\": torch.bfloat16,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cuda\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = LTX2AudioVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\", origin_file_pattern=\"model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"text_encoder_post_modules.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"video_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"audio_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"audio_vocoder.safetensors\", **vram_config),\n        ModelConfig(path=\"./models/train/LTX2-T2AV-full/epoch-4.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\"),\n)\nprompt = \"A beautiful sunset over the ocean.\"\nnegative_prompt = \"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, 
incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts.\"\nheight, width, num_frames = 512, 768, 121\nvideo, audio = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    seed=43,\n    height=height,\n    width=width,\n    num_frames=num_frames,\n    tiled=True,\n    cfg_scale=4.0\n)\nwrite_video_audio_ltx2(\n    video=video,\n    audio=audio,\n    output_path='ltx2_onestage.mp4',\n    fps=24,\n    audio_sample_rate=24000,\n)\n"
  },
  {
    "path": "examples/ltx2/model_training/validate_full/LTX-2.3-I2AV.py",
    "content": "import torch\nfrom diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig\nfrom diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2\nfrom diffsynth.utils.data import VideoData\n\nvram_config = {\n    \"offload_dtype\": torch.bfloat16,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cuda\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = LTX2AudioVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\", origin_file_pattern=\"model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2.3-Repackage\", origin_file_pattern=\"text_encoder_post_modules.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2.3-Repackage\", origin_file_pattern=\"video_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2.3-Repackage\", origin_file_pattern=\"audio_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2.3-Repackage\", origin_file_pattern=\"audio_vocoder.safetensors\", **vram_config),\n        ModelConfig(path=\"./models/train/LTX2.3-I2AV-full/epoch-4.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\"),\n)\n\nprompt = \"A beautiful sunset over the ocean.\"\nnegative_prompt = \"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts 
around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts.\"\nheight, width, num_frames = 512, 768, 121\nimage = VideoData(\"data/example_video_dataset/ltx2/video.mp4\", height=height, width=width)[0]\n# first frame\nvideo, audio = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    seed=43,\n    height=height,\n    width=width,\n    num_frames=num_frames,\n    tiled=False,\n    input_images=[image],\n    input_images_indexes=[0],\n    input_images_strength=1.0,\n    num_inference_steps=40,\n)\nwrite_video_audio_ltx2(\n    video=video,\n    audio=audio,\n    output_path='ltx2.3_onestage_i2av_first.mp4',\n    fps=24,\n    audio_sample_rate=pipe.audio_vocoder.output_sampling_rate,\n)\n"
  },
  {
    "path": "examples/ltx2/model_training/validate_full/LTX-2.3-T2AV.py",
    "content": "import torch\nfrom diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig\nfrom diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2\n\nvram_config = {\n    \"offload_dtype\": torch.bfloat16,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cuda\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = LTX2AudioVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\", origin_file_pattern=\"model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2.3-Repackage\", origin_file_pattern=\"text_encoder_post_modules.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2.3-Repackage\", origin_file_pattern=\"video_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2.3-Repackage\", origin_file_pattern=\"audio_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2.3-Repackage\", origin_file_pattern=\"audio_vocoder.safetensors\", **vram_config),\n        ModelConfig(path=\"./models/train/LTX2.3-T2AV-full/epoch-4.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\"),\n)\nprompt = \"A beautiful sunset over the ocean.\"\nnegative_prompt = \"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera 
shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts.\"\nheight, width, num_frames = 512, 768, 121\nvideo, audio = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    seed=43,\n    height=height,\n    width=width,\n    num_frames=num_frames,\n    tiled=True,\n    cfg_scale=4.0\n)\nwrite_video_audio_ltx2(\n    video=video,\n    audio=audio,\n    output_path='ltx2_onestage.mp4',\n    fps=24,\n    audio_sample_rate=24000,\n)\n"
  },
  {
    "path": "examples/ltx2/model_training/validate_lora/LTX-2-T2AV-IC-LoRA.py",
    "content": "import torch\nfrom diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig\nfrom diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2\nfrom diffsynth.utils.data import VideoData\n\nvram_config = {\n    \"offload_dtype\": torch.bfloat16,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cuda\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = LTX2AudioVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\", origin_file_pattern=\"model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"transformer.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"text_encoder_post_modules.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"video_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"audio_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"audio_vocoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"video_vae_encoder.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\"),\n)\npipe.load_lora(pipe.dit, \"./models/train/LTX2-T2AV-IC-LoRA/epoch-4.safetensors\")\nprompt = \"[VISUAL]:Two cute orange cats, wearing boxing gloves, stand on a boxing ring and fight each other. 
[SOUNDS]:the sound of two cats boxing\"\nnegative_prompt = \"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts.\"\nheight, width, num_frames = 512, 768, 81\nref_scale_factor = 2\nframe_rate = 24\ninput_video = VideoData(\"data/example_video_dataset/ltx2/depth_video.mp4\", height=height // ref_scale_factor // 2, width=width // ref_scale_factor // 2)\ninput_video = input_video.raw_data()\nvideo, audio = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    seed=43,\n    height=height,\n    width=width,\n    num_frames=num_frames,\n    frame_rate=frame_rate,\n    tiled=True,\n    in_context_videos=[input_video],\n    in_context_downsample_factor=ref_scale_factor,\n)\nwrite_video_audio_ltx2(\n    video=video,\n    audio=audio,\n    output_path='ltx2_onestage_ic.mp4',\n    fps=frame_rate,\n    audio_sample_rate=24000,\n)\n"
  },
  {
    "path": "examples/ltx2/model_training/validate_lora/LTX-2-T2AV.py",
    "content": "import torch\nfrom diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig\nfrom diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2\n\nvram_config = {\n    \"offload_dtype\": torch.bfloat16,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cuda\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = LTX2AudioVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\", origin_file_pattern=\"model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"transformer.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"text_encoder_post_modules.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"video_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"audio_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"audio_vocoder.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\"),\n)\npipe.load_lora(pipe.dit, \"models/train/LTX2-T2AV_lora/epoch-4.safetensors\")\nprompt = \"A beautiful sunset over the ocean.\"\nnegative_prompt = \"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, 
extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts.\"\nheight, width, num_frames = 512, 768, 121\nvideo, audio = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    seed=43,\n    height=height,\n    width=width,\n    num_frames=num_frames,\n    tiled=True,\n    cfg_scale=4.0\n)\nwrite_video_audio_ltx2(\n    video=video,\n    audio=audio,\n    output_path='ltx2_onestage.mp4',\n    fps=24,\n    audio_sample_rate=24000,\n)\n"
  },
  {
    "path": "examples/ltx2/model_training/validate_lora/LTX-2-T2AV_noaudio.py",
    "content": "import torch\nfrom diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig\nfrom diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2\n\nvram_config = {\n    \"offload_dtype\": torch.bfloat16,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cuda\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = LTX2AudioVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\", origin_file_pattern=\"model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"transformer.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"text_encoder_post_modules.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"video_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"audio_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2-Repackage\", origin_file_pattern=\"audio_vocoder.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\"),\n)\npipe.load_lora(pipe.dit, \"models/train/LTX2-T2AV-noaudio_lora/epoch-4.safetensors\")\nprompt = \"A beautiful sunset over the ocean.\"\nnegative_prompt = \"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial 
features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts.\"\nheight, width, num_frames = 512, 768, 121\nvideo, audio = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    seed=43,\n    height=height,\n    width=width,\n    num_frames=num_frames,\n    tiled=True,\n    cfg_scale=4.0\n)\nwrite_video_audio_ltx2(\n    video=video,\n    audio=audio,\n    output_path='ltx2_onestage.mp4',\n    fps=24,\n    audio_sample_rate=24000,\n)\n"
  },
  {
    "path": "examples/ltx2/model_training/validate_lora/LTX-2.3-I2AV.py",
    "content": "import torch\nfrom diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig\nfrom diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2\nfrom diffsynth.utils.data import VideoData\n\nvram_config = {\n    \"offload_dtype\": torch.bfloat16,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cuda\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = LTX2AudioVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\", origin_file_pattern=\"model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2.3-Repackage\", origin_file_pattern=\"transformer.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2.3-Repackage\", origin_file_pattern=\"text_encoder_post_modules.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2.3-Repackage\", origin_file_pattern=\"video_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2.3-Repackage\", origin_file_pattern=\"audio_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2.3-Repackage\", origin_file_pattern=\"audio_vocoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2.3-Repackage\", origin_file_pattern=\"video_vae_encoder.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\"),\n)\npipe.load_lora(pipe.dit, \"models/train/LTX2.3-I2AV_lora/epoch-4.safetensors\")\n\nprompt = \"A beautiful sunset over the ocean.\"\nnegative_prompt = \"blurry, out of focus, overexposed, underexposed, low contrast, washed out 
colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts.\"\nheight, width, num_frames = 512, 768, 121\nimage = VideoData(\"data/example_video_dataset/ltx2/video.mp4\", height=height, width=width)[0]\n# first frame\nvideo, audio = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    seed=43,\n    height=height,\n    width=width,\n    num_frames=num_frames,\n    tiled=False,\n    input_images=[image],\n    input_images_indexes=[0],\n    input_images_strength=1.0,\n    num_inference_steps=40,\n)\nwrite_video_audio_ltx2(\n    video=video,\n    audio=audio,\n    output_path='ltx2.3_onestage_i2av_first.mp4',\n    fps=24,\n    audio_sample_rate=pipe.audio_vocoder.output_sampling_rate,\n)\n"
  },
  {
    "path": "examples/ltx2/model_training/validate_lora/LTX-2.3-T2AV-IC-LoRA.py",
    "content": "import torch\nfrom diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig\nfrom diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2\nfrom diffsynth.utils.data import VideoData\n\nvram_config = {\n    \"offload_dtype\": torch.bfloat16,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cuda\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = LTX2AudioVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\", origin_file_pattern=\"model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2.3-Repackage\", origin_file_pattern=\"transformer.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2.3-Repackage\", origin_file_pattern=\"text_encoder_post_modules.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2.3-Repackage\", origin_file_pattern=\"video_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2.3-Repackage\", origin_file_pattern=\"audio_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2.3-Repackage\", origin_file_pattern=\"audio_vocoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2.3-Repackage\", origin_file_pattern=\"video_vae_encoder.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\"),\n)\npipe.load_lora(pipe.dit, \"./models/train/LTX2.3-T2AV-IC-LoRA/epoch-4.safetensors\")\nprompt = \"[VISUAL]:Two cute orange cats, wearing boxing gloves, stand on a boxing ring and fight each other. 
[SOUNDS]:the sound of two cats boxing\"\nnegative_prompt = \"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts.\"\nheight, width, num_frames = 512, 768, 81\nref_scale_factor = 2\nframe_rate = 24\ninput_video = VideoData(\"data/example_video_dataset/ltx2/depth_video.mp4\", height=height // ref_scale_factor // 2, width=width // ref_scale_factor // 2)\ninput_video = input_video.raw_data()\nvideo, audio = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    seed=43,\n    height=height,\n    width=width,\n    num_frames=num_frames,\n    frame_rate=frame_rate,\n    tiled=True,\n    in_context_videos=[input_video],\n    in_context_downsample_factor=ref_scale_factor,\n)\nwrite_video_audio_ltx2(\n    video=video,\n    audio=audio,\n    output_path='ltx2.3_onestage_ic.mp4',\n    fps=frame_rate,\n    audio_sample_rate=pipe.audio_vocoder.output_sampling_rate,\n)\n"
  },
  {
    "path": "examples/ltx2/model_training/validate_lora/LTX-2.3-T2AV.py",
    "content": "import torch\nfrom diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig\nfrom diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2\n\nvram_config = {\n    \"offload_dtype\": torch.bfloat16,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cuda\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = LTX2AudioVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\", origin_file_pattern=\"model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2.3-Repackage\", origin_file_pattern=\"transformer.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2.3-Repackage\", origin_file_pattern=\"text_encoder_post_modules.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2.3-Repackage\", origin_file_pattern=\"video_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2.3-Repackage\", origin_file_pattern=\"audio_vae_decoder.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/LTX-2.3-Repackage\", origin_file_pattern=\"audio_vocoder.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\"),\n)\npipe.load_lora(pipe.dit, \"models/train/LTX2.3-T2AV_lora/epoch-4.safetensors\")\nprompt = \"A beautiful sunset over the ocean.\"\nnegative_prompt = \"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial 
features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts.\"\nheight, width, num_frames = 512, 768, 121\nvideo, audio = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    seed=43,\n    height=height,\n    width=width,\n    num_frames=num_frames,\n    tiled=True,\n    cfg_scale=4.0\n)\nwrite_video_audio_ltx2(\n    video=video,\n    audio=audio,\n    output_path='ltx2_onestage.mp4',\n    fps=24,\n    audio_sample_rate=pipe.audio_vocoder.output_sampling_rate,\n)\n"
  },
  {
    "path": "examples/mova/README.md",
    "content": "English Document: https://diffsynth-studio-doc.readthedocs.io/en/latest/Model_Details/Wan.html\n\n中文文档：https://diffsynth-studio-doc.readthedocs.io/zh-cn/latest/Model_Details/Wan.html\n"
  },
  {
    "path": "examples/mova/acceleration/unified_sequence_parallel.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data.audio_video import write_video_audio\nfrom diffsynth.pipelines.mova_audio_video import MovaAudioVideoPipeline, ModelConfig\nimport torch.distributed as dist\n\nvram_config = {\n    \"offload_dtype\": torch.bfloat16,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cuda\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = MovaAudioVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    use_usp=True,\n    model_configs=[\n        ModelConfig(model_id=\"openmoss/MOVA-360p\", origin_file_pattern=\"video_dit/diffusion_pytorch_model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"openmoss/MOVA-360p\", origin_file_pattern=\"video_dit_2/diffusion_pytorch_model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"openmoss/MOVA-360p\", origin_file_pattern=\"audio_dit/diffusion_pytorch_model.safetensors\", **vram_config),\n        ModelConfig(model_id=\"openmoss/MOVA-360p\", origin_file_pattern=\"dual_tower_bridge/diffusion_pytorch_model.safetensors\", **vram_config),\n        ModelConfig(model_id=\"openmoss/MOVA-720p\", origin_file_pattern=\"audio_vae/diffusion_pytorch_model.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/Wan-Series-Converted-Safetensors\", origin_file_pattern=\"Wan2.1_VAE.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/Wan-Series-Converted-Safetensors\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"openmoss/MOVA-720p\", origin_file_pattern=\"tokenizer/\"),\n)\nnegative_prompt = (\n    \"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，\"\n    \"整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指\"\n)\n\nprompt = \"Two cute 
orange cats, wearing boxing gloves, stand on a boxing ring and fight each other.\"\nheight, width, num_frames = 352, 640, 121\nframe_rate=24\ninput_image = Image.open(\"data/examples/wan/cat_fightning.jpg\").resize((width, height)).convert(\"RGB\")\n# Image-to-video\nvideo, audio = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    height=height,\n    width=width,\n    num_frames=num_frames,\n    input_image=input_image,\n    num_inference_steps=50,\n    seed=0,\n    tiled=True,\n    frame_rate=frame_rate,\n)\nif dist.get_rank() == 0:\n    write_video_audio(video, audio, \"MOVA-360p-cat.mp4\", fps=24, audio_sample_rate=pipe.audio_vae.sample_rate)\n"
  },
  {
    "path": "examples/mova/model_inference/MOVA-360p-I2AV.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.pipelines.mova_audio_video import ModelConfig, MovaAudioVideoPipeline\nfrom diffsynth.utils.data.audio_video import write_video_audio\n\nvram_config = {\n    \"offload_dtype\": torch.bfloat16,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cuda\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = MovaAudioVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"openmoss/MOVA-360p\", origin_file_pattern=\"video_dit/diffusion_pytorch_model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"openmoss/MOVA-360p\", origin_file_pattern=\"video_dit_2/diffusion_pytorch_model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"openmoss/MOVA-360p\", origin_file_pattern=\"audio_dit/diffusion_pytorch_model.safetensors\", **vram_config),\n        ModelConfig(model_id=\"openmoss/MOVA-360p\", origin_file_pattern=\"dual_tower_bridge/diffusion_pytorch_model.safetensors\", **vram_config),\n        ModelConfig(model_id=\"openmoss/MOVA-720p\", origin_file_pattern=\"audio_vae/diffusion_pytorch_model.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/Wan-Series-Converted-Safetensors\", origin_file_pattern=\"Wan2.1_VAE.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/Wan-Series-Converted-Safetensors\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"openmoss/MOVA-720p\", origin_file_pattern=\"tokenizer/\"),\n)\nnegative_prompt = (\n    \"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，\"\n    \"整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指\"\n)\n\nprompt = \"Two cute orange cats, wearing boxing gloves, stand on a boxing 
ring and fight each other.\"\nheight, width, num_frames = 352, 640, 121\nframe_rate = 24\ninput_image = Image.open(\"data/examples/wan/cat_fightning.jpg\").resize((width, height)).convert(\"RGB\")\n# Image-to-video\nvideo, audio = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    height=height,\n    width=width,\n    num_frames=num_frames,\n    input_image=input_image,\n    num_inference_steps=50,\n    seed=0,\n    tiled=True,\n    frame_rate=frame_rate,\n)\nwrite_video_audio(video, audio, \"MOVA-360p-cat.mp4\", fps=24, audio_sample_rate=pipe.audio_vae.sample_rate)\n"
  },
  {
    "path": "examples/mova/model_inference/MOVA-720p-I2AV.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data.audio_video import write_video_audio\nfrom diffsynth.pipelines.mova_audio_video import MovaAudioVideoPipeline, ModelConfig\n\nvram_config = {\n    \"offload_dtype\": torch.bfloat16,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cuda\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = MovaAudioVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"openmoss/MOVA-720p\", origin_file_pattern=\"video_dit/diffusion_pytorch_model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"openmoss/MOVA-720p\", origin_file_pattern=\"video_dit_2/diffusion_pytorch_model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"openmoss/MOVA-720p\", origin_file_pattern=\"audio_dit/diffusion_pytorch_model.safetensors\", **vram_config),\n        ModelConfig(model_id=\"openmoss/MOVA-720p\", origin_file_pattern=\"dual_tower_bridge/diffusion_pytorch_model.safetensors\", **vram_config),\n        ModelConfig(model_id=\"openmoss/MOVA-720p\", origin_file_pattern=\"audio_vae/diffusion_pytorch_model.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/Wan-Series-Converted-Safetensors\", origin_file_pattern=\"Wan2.1_VAE.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/Wan-Series-Converted-Safetensors\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"openmoss/MOVA-720p\", origin_file_pattern=\"tokenizer/\"),\n)\n\nnegative_prompt = (\n    \"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，\"\n    \"整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指\"\n)\nprompt = \"Two cute orange cats, wearing boxing gloves, stand on a boxing 
ring and fight each other.\"\nheight, width, num_frames = 720, 1280, 121\nframe_rate = 24\ninput_image = Image.open(\"data/examples/wan/cat_fightning.jpg\").resize((width, height)).convert(\"RGB\")\n# Image-to-video\nvideo, audio = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    height=height,\n    width=width,\n    num_frames=num_frames,\n    input_image=input_image,\n    num_inference_steps=50,\n    seed=0,\n    tiled=True,\n    frame_rate=frame_rate,\n)\nwrite_video_audio(video, audio, \"MOVA-720p-cat.mp4\", fps=24, audio_sample_rate=pipe.audio_vae.sample_rate)\n"
  },
  {
    "path": "examples/mova/model_inference_low_vram/MOVA-360p-I2AV.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.pipelines.mova_audio_video import ModelConfig, MovaAudioVideoPipeline\nfrom diffsynth.utils.data.audio_video import write_video_audio\n\nvram_config = {\n    \"offload_dtype\": torch.bfloat16,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cuda\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = MovaAudioVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"openmoss/MOVA-360p\", origin_file_pattern=\"video_dit/diffusion_pytorch_model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"openmoss/MOVA-360p\", origin_file_pattern=\"video_dit_2/diffusion_pytorch_model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"openmoss/MOVA-360p\", origin_file_pattern=\"audio_dit/diffusion_pytorch_model.safetensors\", **vram_config),\n        ModelConfig(model_id=\"openmoss/MOVA-360p\", origin_file_pattern=\"dual_tower_bridge/diffusion_pytorch_model.safetensors\", **vram_config),\n        ModelConfig(model_id=\"openmoss/MOVA-720p\", origin_file_pattern=\"audio_vae/diffusion_pytorch_model.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/Wan-Series-Converted-Safetensors\", origin_file_pattern=\"Wan2.1_VAE.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/Wan-Series-Converted-Safetensors\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"openmoss/MOVA-720p\", origin_file_pattern=\"tokenizer/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 2,\n)\nnegative_prompt = (\n    \"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，\"\n    
\"整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指\"\n)\n\nprompt = \"Two cute orange cats, wearing boxing gloves, stand on a boxing ring and fight each other.\"\nheight, width, num_frames = 352, 640, 121\nframe_rate = 24\ninput_image = Image.open(\"data/examples/wan/cat_fightning.jpg\").resize((width, height)).convert(\"RGB\")\n# Image-to-video\nvideo, audio = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    height=height,\n    width=width,\n    num_frames=num_frames,\n    input_image=input_image,\n    num_inference_steps=50,\n    seed=0,\n    tiled=True,\n    frame_rate=frame_rate,\n)\nwrite_video_audio(video, audio, \"MOVA-360p-cat.mp4\", fps=24, audio_sample_rate=pipe.audio_vae.sample_rate)\n"
  },
  {
    "path": "examples/mova/model_inference_low_vram/MOVA-720p-I2AV.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data.audio_video import write_video_audio\nfrom diffsynth.pipelines.mova_audio_video import MovaAudioVideoPipeline, ModelConfig\n\nvram_config = {\n    \"offload_dtype\": torch.bfloat16,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cuda\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = MovaAudioVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"openmoss/MOVA-720p\", origin_file_pattern=\"video_dit/diffusion_pytorch_model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"openmoss/MOVA-720p\", origin_file_pattern=\"video_dit_2/diffusion_pytorch_model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"openmoss/MOVA-720p\", origin_file_pattern=\"audio_dit/diffusion_pytorch_model.safetensors\", **vram_config),\n        ModelConfig(model_id=\"openmoss/MOVA-720p\", origin_file_pattern=\"dual_tower_bridge/diffusion_pytorch_model.safetensors\", **vram_config),\n        ModelConfig(model_id=\"openmoss/MOVA-720p\", origin_file_pattern=\"audio_vae/diffusion_pytorch_model.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/Wan-Series-Converted-Safetensors\", origin_file_pattern=\"Wan2.1_VAE.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/Wan-Series-Converted-Safetensors\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"openmoss/MOVA-720p\", origin_file_pattern=\"tokenizer/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 2,\n)\n\nnegative_prompt = (\n    \"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，\"\n    
\"整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指\"\n)\nprompt = \"Two cute orange cats, wearing boxing gloves, stand on a boxing ring and fight each other.\"\nheight, width, num_frames = 720, 1280, 121\nframe_rate = 24\ninput_image = Image.open(\"data/examples/wan/cat_fightning.jpg\").resize((width, height)).convert(\"RGB\")\n# Image-to-video\nvideo, audio = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    height=height,\n    width=width,\n    num_frames=num_frames,\n    input_image=input_image,\n    num_inference_steps=50,\n    seed=0,\n    tiled=True,\n    frame_rate=frame_rate,\n)\nwrite_video_audio(video, audio, \"MOVA-720p-cat.mp4\", fps=24, audio_sample_rate=pipe.audio_vae.sample_rate)\n"
  },
  {
    "path": "examples/mova/model_training/full/MOVA-360P-I2AV.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"mova/MOVA-360P-I2AV/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/mova/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/mova/MOVA-360P-I2AV \\\n  --dataset_metadata_path data/diffsynth_example_dataset/mova/MOVA-360P-I2AV/metadata.csv \\\n  --data_file_keys \"video,input_audio\" \\\n  --extra_inputs \"input_audio,input_image\" \\\n  --height 352 \\\n  --width 640 \\\n  --num_frames 121 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"openmoss/MOVA-360p:video_dit/diffusion_pytorch_model-*.safetensors,openmoss/MOVA-360p:audio_dit/diffusion_pytorch_model.safetensors,openmoss/MOVA-360p:dual_tower_bridge/diffusion_pytorch_model.safetensors,openmoss/MOVA-720p:audio_vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:Wan2.1_VAE.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:models_t5_umt5-xxl-enc-bf16.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.video_dit.\" \\\n  --output_path \"./models/train/MOVA-360p-I2AV_high_noise_full\" \\\n  --trainable_models \"dit\" \\\n  --max_timestep_boundary 0.358 \\\n  --min_timestep_boundary 0 \\\n  --use_gradient_checkpointing\n# boundary corresponds to timesteps [900, 1000]\n\naccelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/mova/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/mova/MOVA-360P-I2AV \\\n  --dataset_metadata_path data/diffsynth_example_dataset/mova/MOVA-360P-I2AV/metadata.csv \\\n  --data_file_keys \"video,input_audio\" \\\n  --extra_inputs \"input_audio,input_image\" \\\n  --height 352 \\\n  --width 640 \\\n  --num_frames 121 \\\n  --dataset_repeat 100 \\\n  
--model_id_with_origin_paths \"openmoss/MOVA-360p:video_dit_2/diffusion_pytorch_model-*.safetensors,openmoss/MOVA-360p:audio_dit/diffusion_pytorch_model.safetensors,openmoss/MOVA-360p:dual_tower_bridge/diffusion_pytorch_model.safetensors,openmoss/MOVA-720p:audio_vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:Wan2.1_VAE.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:models_t5_umt5-xxl-enc-bf16.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.video_dit.\" \\\n  --output_path \"./models/train/MOVA-360p-I2AV_low_noise_full\" \\\n  --trainable_models \"dit\" \\\n  --max_timestep_boundary 1 \\\n  --min_timestep_boundary 0.358 \\\n  --use_gradient_checkpointing\n# boundary corresponds to timesteps [0, 900)\n"
  },
  {
    "path": "examples/mova/model_training/full/MOVA-720P-I2AV.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"mova/MOVA-720P-I2AV/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/mova/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/mova/MOVA-720P-I2AV \\\n  --dataset_metadata_path data/diffsynth_example_dataset/mova/MOVA-720P-I2AV/metadata.csv \\\n  --data_file_keys \"video,input_audio\" \\\n  --extra_inputs \"input_audio,input_image\" \\\n  --height 720 \\\n  --width 1280 \\\n  --num_frames 121 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"openmoss/MOVA-720p:video_dit/diffusion_pytorch_model-*.safetensors,openmoss/MOVA-720p:audio_dit/diffusion_pytorch_model.safetensors,openmoss/MOVA-720p:dual_tower_bridge/diffusion_pytorch_model.safetensors,openmoss/MOVA-720p:audio_vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:Wan2.1_VAE.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:models_t5_umt5-xxl-enc-bf16.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.video_dit.\" \\\n  --output_path \"./models/train/MOVA-720p-I2AV_high_noise_full\" \\\n  --trainable_models \"dit\" \\\n  --max_timestep_boundary 0.358 \\\n  --min_timestep_boundary 0 \\\n  --use_gradient_checkpointing\n# boundary corresponds to timesteps [900, 1000]\n\naccelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/mova/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/mova/MOVA-720P-I2AV \\\n  --dataset_metadata_path data/diffsynth_example_dataset/mova/MOVA-720P-I2AV/metadata.csv \\\n  --data_file_keys \"video,input_audio\" \\\n  --extra_inputs \"input_audio,input_image\" \\\n  --height 720 \\\n  --width 1280 \\\n  --num_frames 121 \\\n  --dataset_repeat 100 \\\n  
--model_id_with_origin_paths \"openmoss/MOVA-720p:video_dit_2/diffusion_pytorch_model-*.safetensors,openmoss/MOVA-720p:audio_dit/diffusion_pytorch_model.safetensors,openmoss/MOVA-720p:dual_tower_bridge/diffusion_pytorch_model.safetensors,openmoss/MOVA-720p:audio_vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:Wan2.1_VAE.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:models_t5_umt5-xxl-enc-bf16.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.video_dit.\" \\\n  --output_path \"./models/train/MOVA-720p-I2AV_low_noise_full\" \\\n  --trainable_models \"dit\" \\\n  --max_timestep_boundary 1 \\\n  --min_timestep_boundary 0.358 \\\n  --use_gradient_checkpointing\n# boundary corresponds to timesteps [0, 900)\n"
  },
  {
    "path": "examples/mova/model_training/lora/MOVA-360P-I2AV.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"mova/MOVA-360P-I2AV/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/mova/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/mova/MOVA-360P-I2AV \\\n  --dataset_metadata_path data/diffsynth_example_dataset/mova/MOVA-360P-I2AV/metadata.csv \\\n  --data_file_keys \"video,input_audio\" \\\n  --extra_inputs \"input_audio,input_image\" \\\n  --height 352 \\\n  --width 640 \\\n  --num_frames 121 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"openmoss/MOVA-360p:video_dit/diffusion_pytorch_model-*.safetensors,openmoss/MOVA-360p:audio_dit/diffusion_pytorch_model.safetensors,openmoss/MOVA-360p:dual_tower_bridge/diffusion_pytorch_model.safetensors,openmoss/MOVA-720p:audio_vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:Wan2.1_VAE.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:models_t5_umt5-xxl-enc-bf16.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.video_dit.\" \\\n  --output_path \"./models/train/MOVA-360p-I2AV_high_noise_lora\" \\\n  --lora_base_model \"video_dit\" \\\n  --lora_target_modules \"q,k,v,o,ffn.0,ffn.2\" \\\n  --lora_rank 32 \\\n  --max_timestep_boundary 0.358 \\\n  --min_timestep_boundary 0 \\\n  --use_gradient_checkpointing\n# boundary corresponds to timesteps [900, 1000]\n\naccelerate launch examples/mova/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/mova/MOVA-360P-I2AV \\\n  --dataset_metadata_path data/diffsynth_example_dataset/mova/MOVA-360P-I2AV/metadata.csv \\\n  --data_file_keys \"video,input_audio\" \\\n  --extra_inputs \"input_audio,input_image\" \\\n  --height 352 \\\n  --width 640 \\\n  --num_frames 121 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths 
\"openmoss/MOVA-360p:video_dit_2/diffusion_pytorch_model-*.safetensors,openmoss/MOVA-360p:audio_dit/diffusion_pytorch_model.safetensors,openmoss/MOVA-360p:dual_tower_bridge/diffusion_pytorch_model.safetensors,openmoss/MOVA-720p:audio_vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:Wan2.1_VAE.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:models_t5_umt5-xxl-enc-bf16.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.video_dit.\" \\\n  --output_path \"./models/train/MOVA-360p-I2AV_low_noise_lora\" \\\n  --lora_base_model \"video_dit\" \\\n  --lora_target_modules \"q,k,v,o,ffn.0,ffn.2\" \\\n  --lora_rank 32 \\\n  --max_timestep_boundary 1 \\\n  --min_timestep_boundary 0.358 \\\n  --use_gradient_checkpointing\n# boundary corresponds to timesteps [0, 900)\n"
  },
  {
    "path": "examples/mova/model_training/lora/MOVA-720P-I2AV.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"mova/MOVA-720P-I2AV/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/mova/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/mova/MOVA-720P-I2AV \\\n  --dataset_metadata_path data/diffsynth_example_dataset/mova/MOVA-720P-I2AV/metadata.csv \\\n  --data_file_keys \"video,input_audio\" \\\n  --extra_inputs \"input_audio,input_image\" \\\n  --height 720 \\\n  --width 1280 \\\n  --num_frames 121 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"openmoss/MOVA-720p:video_dit/diffusion_pytorch_model-*.safetensors,openmoss/MOVA-720p:audio_dit/diffusion_pytorch_model.safetensors,openmoss/MOVA-720p:dual_tower_bridge/diffusion_pytorch_model.safetensors,openmoss/MOVA-720p:audio_vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:Wan2.1_VAE.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:models_t5_umt5-xxl-enc-bf16.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.video_dit.\" \\\n  --output_path \"./models/train/MOVA-720p-I2AV_high_noise_lora\" \\\n  --lora_base_model \"video_dit\" \\\n  --lora_target_modules \"q,k,v,o,ffn.0,ffn.2\" \\\n  --lora_rank 32 \\\n  --max_timestep_boundary 0.358 \\\n  --min_timestep_boundary 0 \\\n  --use_gradient_checkpointing\n# boundary corresponds to timesteps [900, 1000]\n\naccelerate launch examples/mova/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/mova/MOVA-720P-I2AV \\\n  --dataset_metadata_path data/diffsynth_example_dataset/mova/MOVA-720P-I2AV/metadata.csv \\\n  --data_file_keys \"video,input_audio\" \\\n  --extra_inputs \"input_audio,input_image\" \\\n  --height 720 \\\n  --width 1280 \\\n  --num_frames 121 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths 
\"openmoss/MOVA-720p:video_dit_2/diffusion_pytorch_model-*.safetensors,openmoss/MOVA-720p:audio_dit/diffusion_pytorch_model.safetensors,openmoss/MOVA-720p:dual_tower_bridge/diffusion_pytorch_model.safetensors,openmoss/MOVA-720p:audio_vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:Wan2.1_VAE.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:models_t5_umt5-xxl-enc-bf16.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.video_dit.\" \\\n  --output_path \"./models/train/MOVA-720p-I2AV_low_noise_lora\" \\\n  --lora_base_model \"video_dit\" \\\n  --lora_target_modules \"q,k,v,o,ffn.0,ffn.2\" \\\n  --lora_rank 32 \\\n  --max_timestep_boundary 1 \\\n  --min_timestep_boundary 0.358 \\\n  --use_gradient_checkpointing\n# boundary corresponds to timesteps [0, 900)\n"
  },
  {
    "path": "examples/mova/model_training/train.py",
    "content": "import torch, os, argparse, accelerate, warnings\nfrom diffsynth.core import UnifiedDataset\nfrom diffsynth.core.data.operators import LoadAudioWithTorchaudio, ToAbsolutePath, RouteByType, SequencialProcess\nfrom diffsynth.pipelines.mova_audio_video import MovaAudioVideoPipeline, ModelConfig\nfrom diffsynth.diffusion import *\nos.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n\n\nclass MOVATrainingModule(DiffusionTrainingModule):\n    def __init__(\n        self,\n        model_paths=None, model_id_with_origin_paths=None,\n        tokenizer_path=None,\n        trainable_models=None,\n        lora_base_model=None, lora_target_modules=\"\", lora_rank=32, lora_checkpoint=None,\n        preset_lora_path=None, preset_lora_model=None,\n        use_gradient_checkpointing=True,\n        use_gradient_checkpointing_offload=False,\n        extra_inputs=None,\n        fp8_models=None,\n        offload_models=None,\n        device=\"cpu\",\n        task=\"sft\",\n        max_timestep_boundary=1.0,\n        min_timestep_boundary=0.0,\n    ):\n        super().__init__()\n        # Warning\n        if not use_gradient_checkpointing:\n            warnings.warn(\"Gradient checkpointing is detected as disabled. 
To prevent out-of-memory errors, the training framework will forcibly enable gradient checkpointing.\")\n            use_gradient_checkpointing = True\n\n        # Load models\n        model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device)\n        tokenizer_config = ModelConfig(model_id=\"google/gemma-3-12b-it-qat-q4_0-unquantized\") if tokenizer_path is None else ModelConfig(tokenizer_path)\n        self.pipe = MovaAudioVideoPipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config)\n        self.pipe = self.split_pipeline_units(\n            task, self.pipe, trainable_models, lora_base_model,\n            remove_unnecessary_params=True,\n            force_remove_params_shared=(\"audio_latents\", \"video_latents\"),\n            force_remove_params_nega=(\"audio_context\", \"video_context\")\n        )\n        # Training mode\n        self.switch_pipe_to_training_mode(\n            self.pipe, trainable_models,\n            lora_base_model, lora_target_modules, lora_rank, lora_checkpoint,\n            preset_lora_path, preset_lora_model,\n            task=task,\n        )\n\n        # Store other configs\n        self.use_gradient_checkpointing = use_gradient_checkpointing\n        self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload\n        self.extra_inputs = extra_inputs.split(\",\") if extra_inputs is not None else []\n        self.fp8_models = fp8_models\n        self.task = task\n        self.task_to_loss = {\n            \"sft:data_process\": lambda pipe, *args: args,\n            \"sft\": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTAudioVideoLoss(pipe, **inputs_shared, **inputs_posi),\n            \"sft:train\": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTAudioVideoLoss(pipe, **inputs_shared, **inputs_posi),\n        }\n  
      self.max_timestep_boundary = max_timestep_boundary\n        self.min_timestep_boundary = min_timestep_boundary\n\n    def parse_extra_inputs(self, data, extra_inputs, inputs_shared):\n        for extra_input in extra_inputs:\n            if extra_input == \"input_image\":\n                inputs_shared[\"input_image\"] = data[\"video\"][0]\n            else:\n                inputs_shared[extra_input] = data[extra_input]\n        return inputs_shared\n\n    def get_pipeline_inputs(self, data):\n        inputs_posi = {\"prompt\": data[\"prompt\"]}\n        inputs_nega = {}\n        inputs_shared = {\n            # Assume you are using this pipeline for inference,\n            # please fill in the input parameters.\n            \"input_video\": data[\"video\"],\n            \"height\": data[\"video\"][0].size[1],\n            \"width\": data[\"video\"][0].size[0],\n            \"num_frames\": len(data[\"video\"]),\n            \"frame_rate\": data.get(\"frame_rate\", 24),\n            # Please do not modify the following parameters\n            # unless you clearly know what this will cause.\n            \"cfg_scale\": 1,\n            \"tiled\": False,\n            \"rand_device\": self.pipe.device,\n            \"use_gradient_checkpointing\": self.use_gradient_checkpointing,\n            \"use_gradient_checkpointing_offload\": self.use_gradient_checkpointing_offload,\n            \"max_timestep_boundary\": self.max_timestep_boundary,\n            \"min_timestep_boundary\": self.min_timestep_boundary,\n        }\n        inputs_shared = self.parse_extra_inputs(data, self.extra_inputs, inputs_shared)\n        return inputs_shared, inputs_posi, inputs_nega\n\n    def forward(self, data, inputs=None):\n        if inputs is None: inputs = self.get_pipeline_inputs(data)\n        inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype)\n        for unit in self.pipe.units:\n            inputs = self.pipe.unit_runner(unit, self.pipe, 
*inputs)\n        loss = self.task_to_loss[self.task](self.pipe, *inputs)\n        return loss\n\n\ndef ltx2_parser():\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser = add_general_config(parser)\n    parser = add_video_size_config(parser)\n    parser.add_argument(\"--tokenizer_path\", type=str, default=None, help=\"Path to tokenizer.\")\n    parser.add_argument(\"--frame_rate\", type=float, default=24, help=\"Frame rate of the training videos. Mova is trained with a frame rate of 24, so it's recommended to use the same frame rate.\")\n    parser.add_argument(\"--max_timestep_boundary\", type=float, default=1.0, help=\"Max timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).\")\n    parser.add_argument(\"--min_timestep_boundary\", type=float, default=0.0, help=\"Min timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).\")\n    parser.add_argument(\"--initialize_model_on_cpu\", default=False, action=\"store_true\", help=\"Whether to initialize models on CPU.\")\n    return parser\n\n\nif __name__ == \"__main__\":\n    parser = ltx2_parser()\n    args = parser.parse_args()\n    accelerator = accelerate.Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        kwargs_handlers=[accelerate.DistributedDataParallelKwargs(find_unused_parameters=args.find_unused_parameters)],\n    )\n    model = MOVATrainingModule(\n        model_paths=args.model_paths,\n        model_id_with_origin_paths=args.model_id_with_origin_paths,\n        tokenizer_path=args.tokenizer_path,\n        trainable_models=args.trainable_models,\n        lora_base_model=args.lora_base_model,\n        lora_target_modules=args.lora_target_modules,\n        lora_rank=args.lora_rank,\n        lora_checkpoint=args.lora_checkpoint,\n        preset_lora_path=args.preset_lora_path,\n        preset_lora_model=args.preset_lora_model,\n        
use_gradient_checkpointing=args.use_gradient_checkpointing,\n        use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,\n        extra_inputs=args.extra_inputs,\n        fp8_models=args.fp8_models,\n        offload_models=args.offload_models,\n        task=args.task,\n        device=\"cpu\" if args.initialize_model_on_cpu else accelerator.device,\n        max_timestep_boundary=args.max_timestep_boundary,\n        min_timestep_boundary=args.min_timestep_boundary,\n    )\n    video_processor = UnifiedDataset.default_video_operator(\n        base_path=args.dataset_base_path,\n        max_pixels=args.max_pixels,\n        height=args.height,\n        width=args.width,\n        height_division_factor=16,\n        width_division_factor=16,\n        num_frames=args.num_frames,\n        time_division_factor=4,\n        time_division_remainder=1,\n        frame_rate=args.frame_rate,\n        fix_frame_rate=True,\n    )\n    dataset = UnifiedDataset(\n        base_path=args.dataset_base_path,\n        metadata_path=args.dataset_metadata_path,\n        repeat=args.dataset_repeat,\n        data_file_keys=args.data_file_keys.split(\",\"),\n        main_data_operator=video_processor,\n        special_operator_map={\n            \"input_audio\":\n                ToAbsolutePath(args.dataset_base_path) >> LoadAudioWithTorchaudio(\n                    num_frames=args.num_frames,\n                    time_division_factor=4,\n                    time_division_remainder=1,\n                    frame_rate=args.frame_rate,\n                ),\n            \"in_context_videos\":\n                RouteByType(operator_map=[\n                    (str, video_processor),\n                    (list, SequencialProcess(video_processor)),\n                ]),\n        },\n    )\n\n    model_logger = ModelLogger(\n        args.output_path,\n        remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,\n    )\n    launcher_map = {\n        \"sft:data_process\": 
launch_data_process_task,\n        \"direct_distill:data_process\": launch_data_process_task,\n        \"sft\": launch_training_task,\n        \"sft:train\": launch_training_task,\n        \"direct_distill\": launch_training_task,\n        \"direct_distill:train\": launch_training_task,\n    }\n    launcher_map[args.task](accelerator, dataset, model, model_logger, args=args)\n"
  },
  {
    "path": "examples/mova/model_training/validate_full/MOVA-360p-I2AV.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.pipelines.mova_audio_video import ModelConfig, MovaAudioVideoPipeline\nfrom diffsynth.utils.data.audio_video import write_video_audio\nfrom diffsynth.utils.data import VideoData\n\n\nvram_config = {\n    \"offload_dtype\": torch.bfloat16,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cuda\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = MovaAudioVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(path=\"./models/train/MOVA-360p-I2AV_high_noise_full/epoch-4.safetensors\", **vram_config),\n        ModelConfig(model_id=\"openmoss/MOVA-360p\", origin_file_pattern=\"video_dit_2/diffusion_pytorch_model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"openmoss/MOVA-360p\", origin_file_pattern=\"audio_dit/diffusion_pytorch_model.safetensors\", **vram_config),\n        ModelConfig(model_id=\"openmoss/MOVA-360p\", origin_file_pattern=\"dual_tower_bridge/diffusion_pytorch_model.safetensors\", **vram_config),\n        ModelConfig(model_id=\"openmoss/MOVA-720p\", origin_file_pattern=\"audio_vae/diffusion_pytorch_model.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/Wan-Series-Converted-Safetensors\", origin_file_pattern=\"Wan2.1_VAE.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/Wan-Series-Converted-Safetensors\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"openmoss/MOVA-720p\", origin_file_pattern=\"tokenizer/\"),\n)\nnegative_prompt = (\n    \"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，\"\n    \"整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指\"\n)\nprompt = \"A beautiful sunset over the ocean.\"\nheight, 
width, num_frames = 352, 640, 121\nframe_rate = 24\ninput_image = VideoData(\"data/example_video_dataset/ltx2/video.mp4\", height=height, width=width)[0]\n# Image-to-video\nvideo, audio = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    height=height,\n    width=width,\n    num_frames=num_frames,\n    input_image=input_image,\n    num_inference_steps=50,\n    seed=0,\n    tiled=True,\n    frame_rate=frame_rate,\n)\nwrite_video_audio(video, audio, \"MOVA-360p.mp4\", fps=24, audio_sample_rate=pipe.audio_vae.sample_rate)\n"
  },
  {
    "path": "examples/mova/model_training/validate_full/MOVA-720p-I2AV.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data.audio_video import write_video_audio\nfrom diffsynth.pipelines.mova_audio_video import MovaAudioVideoPipeline, ModelConfig\nfrom diffsynth.utils.data import VideoData\n\n\nvram_config = {\n    \"offload_dtype\": torch.bfloat16,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cuda\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = MovaAudioVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(path=\"./models/train/MOVA-720p-I2AV_high_noise_full/epoch-4.safetensors\", **vram_config),\n        ModelConfig(model_id=\"openmoss/MOVA-720p\", origin_file_pattern=\"video_dit_2/diffusion_pytorch_model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"openmoss/MOVA-720p\", origin_file_pattern=\"audio_dit/diffusion_pytorch_model.safetensors\", **vram_config),\n        ModelConfig(model_id=\"openmoss/MOVA-720p\", origin_file_pattern=\"dual_tower_bridge/diffusion_pytorch_model.safetensors\", **vram_config),\n        ModelConfig(model_id=\"openmoss/MOVA-720p\", origin_file_pattern=\"audio_vae/diffusion_pytorch_model.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/Wan-Series-Converted-Safetensors\", origin_file_pattern=\"Wan2.1_VAE.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/Wan-Series-Converted-Safetensors\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"openmoss/MOVA-720p\", origin_file_pattern=\"tokenizer/\"),\n)\npipe.load_lora(pipe.video_dit, \"models/train/MOVA-720p-I2AV_high_noise_lora/epoch-4.safetensors\")\nnegative_prompt = (\n    \"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，\"\n    
\"整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指\"\n)\nprompt = \"A beautiful sunset over the ocean.\"\nheight, width, num_frames = 720, 1280, 121\nframe_rate = 24\ninput_image = VideoData(\"data/example_video_dataset/ltx2/video.mp4\", height=height, width=width)[0]\n# Image-to-video\nvideo, audio = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    height=height,\n    width=width,\n    num_frames=num_frames,\n    input_image=input_image,\n    num_inference_steps=50,\n    seed=0,\n    tiled=True,\n    frame_rate=frame_rate,\n)\nwrite_video_audio(video, audio, \"MOVA-720p.mp4\", fps=24, audio_sample_rate=pipe.audio_vae.sample_rate)\n"
  },
  {
    "path": "examples/mova/model_training/validate_lora/MOVA-360p-I2AV.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.pipelines.mova_audio_video import ModelConfig, MovaAudioVideoPipeline\nfrom diffsynth.utils.data.audio_video import write_video_audio\nfrom diffsynth.utils.data import VideoData\n\n\nvram_config = {\n    \"offload_dtype\": torch.bfloat16,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cuda\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = MovaAudioVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"openmoss/MOVA-360p\", origin_file_pattern=\"video_dit/diffusion_pytorch_model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"openmoss/MOVA-360p\", origin_file_pattern=\"video_dit_2/diffusion_pytorch_model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"openmoss/MOVA-360p\", origin_file_pattern=\"audio_dit/diffusion_pytorch_model.safetensors\", **vram_config),\n        ModelConfig(model_id=\"openmoss/MOVA-360p\", origin_file_pattern=\"dual_tower_bridge/diffusion_pytorch_model.safetensors\", **vram_config),\n        ModelConfig(model_id=\"openmoss/MOVA-720p\", origin_file_pattern=\"audio_vae/diffusion_pytorch_model.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/Wan-Series-Converted-Safetensors\", origin_file_pattern=\"Wan2.1_VAE.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/Wan-Series-Converted-Safetensors\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"openmoss/MOVA-720p\", origin_file_pattern=\"tokenizer/\"),\n)\npipe.load_lora(pipe.video_dit, \"models/train/MOVA-360p-I2AV_high_noise_lora/epoch-4.safetensors\")\nnegative_prompt = (\n    
\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，\"\n    \"整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指\"\n)\nprompt = \"A beautiful sunset over the ocean.\"\nheight, width, num_frames = 352, 640, 121\nframe_rate = 24\ninput_image = VideoData(\"data/example_video_dataset/ltx2/video.mp4\", height=height, width=width)[0]\n# Image-to-video\nvideo, audio = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    height=height,\n    width=width,\n    num_frames=num_frames,\n    input_image=input_image,\n    num_inference_steps=50,\n    seed=0,\n    tiled=True,\n    frame_rate=frame_rate,\n)\nwrite_video_audio(video, audio, \"MOVA-360p.mp4\", fps=24, audio_sample_rate=pipe.audio_vae.sample_rate)\n"
  },
  {
    "path": "examples/mova/model_training/validate_lora/MOVA-720p-I2AV.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data.audio_video import write_video_audio\nfrom diffsynth.pipelines.mova_audio_video import MovaAudioVideoPipeline, ModelConfig\nfrom diffsynth.utils.data import VideoData\n\n\nvram_config = {\n    \"offload_dtype\": torch.bfloat16,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cuda\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = MovaAudioVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"openmoss/MOVA-720p\", origin_file_pattern=\"video_dit/diffusion_pytorch_model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"openmoss/MOVA-720p\", origin_file_pattern=\"video_dit_2/diffusion_pytorch_model-*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"openmoss/MOVA-720p\", origin_file_pattern=\"audio_dit/diffusion_pytorch_model.safetensors\", **vram_config),\n        ModelConfig(model_id=\"openmoss/MOVA-720p\", origin_file_pattern=\"dual_tower_bridge/diffusion_pytorch_model.safetensors\", **vram_config),\n        ModelConfig(model_id=\"openmoss/MOVA-720p\", origin_file_pattern=\"audio_vae/diffusion_pytorch_model.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/Wan-Series-Converted-Safetensors\", origin_file_pattern=\"Wan2.1_VAE.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/Wan-Series-Converted-Safetensors\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"openmoss/MOVA-720p\", origin_file_pattern=\"tokenizer/\"),\n)\npipe.load_lora(pipe.video_dit, \"models/train/MOVA-720p-I2AV_high_noise_lora/epoch-4.safetensors\")\nnegative_prompt = (\n    
\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，\"\n    \"整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指\"\n)\nprompt = \"A beautiful sunset over the ocean.\"\nheight, width, num_frames = 720, 1280, 121\nframe_rate = 24\ninput_image = VideoData(\"data/example_video_dataset/ltx2/video.mp4\", height=height, width=width)[0]\n# Image-to-video\nvideo, audio = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    height=height,\n    width=width,\n    num_frames=num_frames,\n    input_image=input_image,\n    num_inference_steps=50,\n    seed=0,\n    tiled=True,\n    frame_rate=frame_rate,\n)\nwrite_video_audio(video, audio, \"MOVA-720p.mp4\", fps=24, audio_sample_rate=pipe.audio_vae.sample_rate)\n"
  },
  {
    "path": "examples/qwen_image/README.md",
    "content": "English Document: https://diffsynth-studio-doc.readthedocs.io/en/latest/Model_Details/Qwen-Image.html\n\n中文文档：https://diffsynth-studio-doc.readthedocs.io/zh-cn/latest/Model_Details/Qwen-Image.html\n"
  },
  {
    "path": "examples/qwen_image/model_inference/FireRed-Image-Edit-1.0.py",
    "content": "from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\nfrom PIL import Image\nimport torch\n\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"FireRedTeam/FireRed-Image-Edit-1.0\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    processor_config=ModelConfig(model_id=\"Qwen/Qwen-Image-Edit\", origin_file_pattern=\"processor/\"),\n)\n\ndataset_snapshot_download(\n    \"DiffSynth-Studio/example_image_dataset\",\n    allow_file_pattern=\"qwen_image_edit/*\",\n    local_dir=\"data/example_image_dataset\",\n)\n\nprompt = \"生成这两个人的合影\"\nedit_image = [\n    Image.open(\"data/example_image_dataset/qwen_image_edit/image1.jpg\"),\n    Image.open(\"data/example_image_dataset/qwen_image_edit/image2.jpg\"),\n]\nimage = pipe(\n    prompt,\n    edit_image=edit_image,\n    seed=1,\n    num_inference_steps=40,\n    height=1152,\n    width=896,\n    edit_image_auto_resize=True,\n)\nimage.save(\"image.jpg\")\n\n# FireRedTeam/FireRed-Image-Edit-1.0 is a multi-image editing model.\n# Please use a list to input `edit_image`, even if the input contains only one image.\n# edit_image = [Image.open(\"image.jpg\")]\n# Please do not input the image directly.\n# edit_image = Image.open(\"image.jpg\")\n"
  },
  {
    "path": "examples/qwen_image/model_inference/FireRed-Image-Edit-1.1.py",
    "content": "from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\nfrom PIL import Image\nimport torch\n\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"FireRedTeam/FireRed-Image-Edit-1.1\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    processor_config=ModelConfig(model_id=\"Qwen/Qwen-Image-Edit\", origin_file_pattern=\"processor/\"),\n)\n\ndataset_snapshot_download(\n    \"DiffSynth-Studio/example_image_dataset\",\n    allow_file_pattern=\"qwen_image_edit/*\",\n    local_dir=\"data/example_image_dataset\",\n)\n\nprompt = \"生成这两个人的合影\"\nedit_image = [\n    Image.open(\"data/example_image_dataset/qwen_image_edit/image1.jpg\"),\n    Image.open(\"data/example_image_dataset/qwen_image_edit/image2.jpg\"),\n]\nimage = pipe(\n    prompt,\n    edit_image=edit_image,\n    seed=1,\n    num_inference_steps=40,\n    height=1152,\n    width=896,\n    edit_image_auto_resize=True,\n)\nimage.save(\"image.jpg\")\n\n# FireRedTeam/FireRed-Image-Edit-1.1 is a multi-image editing model.\n# Please use a list to input `edit_image`, even if the input contains only one image.\n# edit_image = [Image.open(\"image.jpg\")]\n# Please do not input the image directly.\n# edit_image = Image.open(\"image.jpg\")\n"
  },
  {
    "path": "examples/qwen_image/model_inference/Qwen-Image-2512.py",
    "content": "from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nimport torch\n\n\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image-2512\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n)\nprompt = \"精致肖像，水下少女，蓝裙飘逸，发丝轻扬，光影透澈，气泡环绕，面容恬静，细节精致，梦幻唯美。\"\nimage = pipe(prompt, seed=0, num_inference_steps=40)\nimage.save(\"image.jpg\")\n"
  },
  {
    "path": "examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py",
    "content": "from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput\nfrom PIL import Image\nimport torch\nfrom modelscope import dataset_snapshot_download\n\n\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n        ModelConfig(model_id=\"DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny\", origin_file_pattern=\"model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n)\n\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/example_image_dataset\",\n    local_dir=\"./data/example_image_dataset\",\n    allow_file_pattern=\"canny/image_1.jpg\"\n)\ncontrolnet_image = Image.open(\"data/example_image_dataset/canny/image_1.jpg\").resize((1328, 1328))\n\nprompt = \"一只小狗，毛发光洁柔顺，眼神灵动，背景是樱花纷飞的春日庭院，唯美温馨。\"\nimage = pipe(\n    prompt, seed=0,\n    blockwise_controlnet_inputs=[ControlNetInput(image=controlnet_image)]\n)\nimage.save(\"image.jpg\")\n"
  },
  {
    "path": "examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py",
    "content": "from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput\nfrom PIL import Image\nimport torch\nfrom modelscope import dataset_snapshot_download\n\n\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n        ModelConfig(model_id=\"DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth\", origin_file_pattern=\"model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n)\n\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/example_image_dataset\",\n    local_dir=\"./data/example_image_dataset\",\n    allow_file_pattern=\"depth/image_1.jpg\"\n)\n\ncontrolnet_image = Image.open(\"data/example_image_dataset/depth/image_1.jpg\").resize((1328, 1328))\n\nprompt = \"精致肖像，水下少女，蓝裙飘逸，发丝轻扬，光影透澈，气泡环绕，面容恬静，细节精致，梦幻唯美。\"\nimage = pipe(\n    prompt, seed=0,\n    blockwise_controlnet_inputs=[ControlNetInput(image=controlnet_image)]\n)\nimage.save(\"image.jpg\")\n"
  },
  {
    "path": "examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py",
    "content": "import torch\nfrom PIL import Image\nfrom modelscope import dataset_snapshot_download\nfrom diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput\n\n\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n        ModelConfig(model_id=\"DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint\", origin_file_pattern=\"model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n)\n\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/example_image_dataset\",\n    local_dir=\"./data/example_image_dataset\",\n    allow_file_pattern=\"inpaint/*.jpg\"\n)\nprompt = \"a cat with sunglasses\"\ncontrolnet_image = Image.open(\"./data/example_image_dataset/inpaint/image_1.jpg\").convert(\"RGB\").resize((1328, 1328))\ninpaint_mask = Image.open(\"./data/example_image_dataset/inpaint/mask.jpg\").convert(\"RGB\").resize((1328, 1328))\nimage = pipe(\n    prompt, seed=0,\n    input_image=controlnet_image, inpaint_mask=inpaint_mask,\n    blockwise_controlnet_inputs=[ControlNetInput(image=controlnet_image, inpaint_mask=inpaint_mask)],\n    num_inference_steps=40,\n)\nimage.save(\"image.jpg\")\n"
  },
  {
    "path": "examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-InpaintCanny.py",
    "content": "import torch\nfrom PIL import Image\nfrom modelscope import dataset_snapshot_download\nfrom diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput\n\n\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n        ModelConfig(model_id=\"DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint\", origin_file_pattern=\"model.safetensors\"),\n        ModelConfig(model_id=\"DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny\", origin_file_pattern=\"model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n)\n\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/example_image_dataset\",\n    local_dir=\"./data/example_image_dataset\",\n    allow_file_pattern=\"canny/*.jpg\"\n)\nprompt = \"一只小狗，毛发光洁柔顺，眼神灵动，背景是樱花纷飞的春日庭院，唯美温馨。\"\n\ncontrolnet_canny_image = Image.open(\"data/example_image_dataset/canny/image_1.jpg\").resize((1328, 1328))\n\ncontrolnet_inpaint_image = Image.open(\"./data/example_image_dataset/canny/image_2.jpg\").convert(\"RGB\").resize((1328, 1328))\n# generate a centered square mask\ninpaint_mask = Image.new(\"L\", controlnet_inpaint_image.size, 0)\nmask_size = 512\nleft = (controlnet_inpaint_image.width - mask_size) // 2\ntop = (controlnet_inpaint_image.height - mask_size) // 2\nright = left + mask_size\nbottom = top + mask_size\ninpaint_mask.paste(255, (left, top, right, bottom))\ninpaint_mask = inpaint_mask.resize((1328, 1328)).convert(\"RGB\")\n\nimage = pipe(\n    prompt, seed=0,\n    
input_image=controlnet_inpaint_image, inpaint_mask=inpaint_mask,\n    blockwise_controlnet_inputs=[\n        ControlNetInput(image=controlnet_inpaint_image, inpaint_mask=inpaint_mask, controlnet_id=0),\n        ControlNetInput(image=controlnet_canny_image, controlnet_id=1),\n    ],\n    num_inference_steps=40,\n)\nimage.save(\"image.jpg\")\n"
  },
  {
    "path": "examples/qwen_image/model_inference/Qwen-Image-Distill-DMD2.py",
    "content": "from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nfrom diffsynth.core import load_state_dict\nfrom modelscope import snapshot_download\nimport torch, math\n\n\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n)\n\nsnapshot_download(\"MusePublic/Qwen-Image-Distill\", allow_file_pattern=\"qwen_image_distill_3step.safetensors\", cache_dir=\"models\")\nlora_state_dict = load_state_dict(\"models/MusePublic/Qwen-Image-Distill/qwen_image_distill_3step.safetensors\")\nlora_state_dict = {i.replace(\"base_model.model.\", \"\"): j for i, j in lora_state_dict.items()}\npipe.load_lora(pipe.dit, state_dict=lora_state_dict)\n\nprompt = \"精致肖像，水下少女，蓝裙飘逸，发丝轻扬，光影透澈，气泡环绕，面容恬静，细节精致，梦幻唯美。\"\nimage = pipe(prompt, seed=0, num_inference_steps=3, cfg_scale=1, exponential_shift_mu=math.log(2.5))\nimage.save(\"image.jpg\")\n"
  },
  {
    "path": "examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py",
    "content": "from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nimport torch\n\n\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"DiffSynth-Studio/Qwen-Image-Distill-Full\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n)\nprompt = \"精致肖像，水下少女，蓝裙飘逸，发丝轻扬，光影透澈，气泡环绕，面容恬静，细节精致，梦幻唯美。\"\nimage = pipe(prompt, seed=0, num_inference_steps=15, cfg_scale=1)\nimage.save(\"image.jpg\")\n"
  },
  {
    "path": "examples/qwen_image/model_inference/Qwen-Image-Distill-LoRA.py",
    "content": "from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nfrom modelscope import snapshot_download\nimport torch\n\nsnapshot_download(\"DiffSynth-Studio/Qwen-Image-Distill-LoRA\", local_dir=\"models/DiffSynth-Studio/Qwen-Image-Distill-LoRA\")\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n)\npipe.load_lora(pipe.dit, \"models/DiffSynth-Studio/Qwen-Image-Distill-LoRA/model.safetensors\")\n\nprompt = \"精致肖像，水下少女，蓝裙飘逸，发丝轻扬，光影透澈，气泡环绕，面容恬静，细节精致，梦幻唯美。\"\nimage = pipe(prompt, seed=0, num_inference_steps=15, cfg_scale=1)\nimage.save(\"image.jpg\")"
  },
  {
    "path": "examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py",
    "content": "from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nfrom PIL import Image\nimport torch\n\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image-Edit-2509\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    processor_config=ModelConfig(model_id=\"Qwen/Qwen-Image-Edit\", origin_file_pattern=\"processor/\"),\n)\n\nimage_1 = pipe(prompt=\"一位少女\", seed=0, num_inference_steps=40, height=1328, width=1024)\nimage_1.save(\"image1.jpg\")\n\nimage_2 = pipe(prompt=\"一位老人\", seed=0, num_inference_steps=40, height=1328, width=1024)\nimage_2.save(\"image2.jpg\")\n\nprompt = \"生成这两个人的合影\"\nedit_image = [Image.open(\"image1.jpg\"), Image.open(\"image2.jpg\")]\nimage_3 = pipe(prompt, edit_image=edit_image, seed=1, num_inference_steps=40, height=1328, width=1024, edit_image_auto_resize=True)\nimage_3.save(\"image3.jpg\")\n\n# Qwen-Image-Edit-2509 is a multi-image editing model.\n# Please use a list to input `edit_image`, even if the input contains only one image.\n# edit_image = [Image.open(\"image.jpg\")]\n# Please do not input the image directly.\n# edit_image = Image.open(\"image.jpg\")\n"
  },
  {
    "path": "examples/qwen_image/model_inference/Qwen-Image-Edit-2511-ICEdit.py",
    "content": "from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nfrom modelscope import snapshot_download\nfrom PIL import Image\nimport torch\n\n# Load models\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image-Edit-2511\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    processor_config=ModelConfig(model_id=\"Qwen/Qwen-Image-Edit\", origin_file_pattern=\"processor/\"),\n)\nlora = ModelConfig(\n    model_id=\"DiffSynth-Studio/Qwen-Image-Edit-2511-ICEdit-LoRA\",\n    origin_file_pattern=\"model.safetensors\"\n)\npipe.load_lora(pipe.dit, lora)\n\n# Load images\nsnapshot_download(\n    \"DiffSynth-Studio/Qwen-Image-Edit-2511-ICEdit-LoRA\",\n    local_dir=\"./data\",\n    allow_file_pattern=\"assets/*\"\n)\nedit_image = [\n    Image.open(\"data/assets/image1_original.png\"),\n    Image.open(\"data/assets/image1_edit_1.png\"),\n    Image.open(\"data/assets/image2_original.png\")\n]\nprompt = \"Edit image 3 based on the transformation from image 1 to image 2.\"\nnegative_prompt = \"泛黄，AI感，不真实，丑陋，油腻的皮肤，异常的肢体，不协调的肢体\"\n\n# Generate\nimage_4 = pipe(\n    prompt=prompt, negative_prompt=negative_prompt,\n    edit_image=edit_image,\n    seed=1,\n    num_inference_steps=50,\n    height=1280,\n    width=720,\n    zero_cond_t=True,\n)\nimage_4.save(\"image.png\")"
  },
  {
    "path": "examples/qwen_image/model_inference/Qwen-Image-Edit-2511-Lightning.py",
    "content": "from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, FlowMatchScheduler\nfrom modelscope import dataset_snapshot_download\nfrom PIL import Image\nimport torch\n\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image-Edit-2511\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    processor_config=ModelConfig(model_id=\"Qwen/Qwen-Image-Edit\", origin_file_pattern=\"processor/\"),\n)\n\nlora = ModelConfig(\n    model_id=\"lightx2v/Qwen-Image-Edit-2511-Lightning\",\n    origin_file_pattern=\"Qwen-Image-Edit-2511-Lightning-4steps-V1.0-bf16.safetensors\"\n)\npipe.load_lora(pipe.dit, lora, alpha=1)\npipe.scheduler = FlowMatchScheduler(\"Qwen-Image-Lightning\")\n\n\ndataset_snapshot_download(\n    \"DiffSynth-Studio/example_image_dataset\",\n    allow_file_pattern=\"qwen_image_edit/*\",\n    local_dir=\"data/example_image_dataset\",\n)\n\nprompt = \"生成这两个人的合影\"\nedit_image = [\n    Image.open(\"data/example_image_dataset/qwen_image_edit/image1.jpg\"),\n    Image.open(\"data/example_image_dataset/qwen_image_edit/image2.jpg\"),\n]\nimage = pipe(\n    prompt,\n    edit_image=edit_image,\n    seed=1,\n    num_inference_steps=4,\n    height=1152,\n    width=896,\n    edit_image_auto_resize=True,\n    zero_cond_t=True, # This is a special parameter introduced by Qwen-Image-Edit-2511\n    cfg_scale=1.0,\n)\nimage.save(\"image.jpg\")\n\n# Qwen-Image-Edit-2511 is a multi-image editing model.\n# Please use a list to input `edit_image`, even if the input contains only one image.\n# edit_image = [Image.open(\"image.jpg\")]\n# Please do not input the image directly.\n# edit_image 
= Image.open(\"image.jpg\")\n"
  },
  {
    "path": "examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py",
    "content": "from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\nfrom PIL import Image\nimport torch\n\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image-Edit-2511\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    processor_config=ModelConfig(model_id=\"Qwen/Qwen-Image-Edit\", origin_file_pattern=\"processor/\"),\n)\n\ndataset_snapshot_download(\n    \"DiffSynth-Studio/example_image_dataset\",\n    allow_file_pattern=\"qwen_image_edit/*\",\n    local_dir=\"data/example_image_dataset\",\n)\n\nprompt = \"生成这两个人的合影\"\nedit_image = [\n    Image.open(\"data/example_image_dataset/qwen_image_edit/image1.jpg\"),\n    Image.open(\"data/example_image_dataset/qwen_image_edit/image2.jpg\"),\n]\nimage = pipe(\n    prompt,\n    edit_image=edit_image,\n    seed=1,\n    num_inference_steps=40,\n    height=1152,\n    width=896,\n    edit_image_auto_resize=True,\n    zero_cond_t=True, # This is a special parameter introduced by Qwen-Image-Edit-2511\n)\nimage.save(\"image.jpg\")\n\n# Qwen-Image-Edit-2511 is a multi-image editing model.\n# Please use a list to input `edit_image`, even if the input contains only one image.\n# edit_image = [Image.open(\"image.jpg\")]\n# Please do not input the image directly.\n# edit_image = Image.open(\"image.jpg\")\n"
  },
  {
    "path": "examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py",
    "content": "from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nimport torch\nfrom modelscope import snapshot_download\n\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image-Edit\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    processor_config=ModelConfig(model_id=\"Qwen/Qwen-Image-Edit\", origin_file_pattern=\"processor/\"),\n)\nsnapshot_download(\"DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix\", local_dir=\"models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix\", allow_file_pattern=\"model.safetensors\")\npipe.load_lora(pipe.dit, \"models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix/model.safetensors\")\n\nprompt = \"精致肖像，水下少女，蓝裙飘逸，发丝轻扬，光影透澈，气泡环绕，面容恬静，细节精致，梦幻唯美。\"\nimage = pipe(prompt=prompt, seed=0, num_inference_steps=40, height=1024, width=768)\nimage.save(\"image.jpg\")\n\nprompt = \"将裙子变成粉色\"\n# Downscale the generated image to half resolution to simulate a low-resolution edit input.\n# PIL's resize takes (width, height); deriving it from image.size keeps the 3:4 portrait\n# aspect ratio (768x1024 -> 384x512) instead of transposing it.\nimage = image.resize((image.width // 2, image.height // 2))\nimage = pipe(prompt, edit_image=image, seed=1, num_inference_steps=40, height=1024, width=768, edit_rope_interpolation=True, edit_image_auto_resize=False)\nimage.save(\"image2.jpg\")\n"
  },
  {
    "path": "examples/qwen_image/model_inference/Qwen-Image-Edit.py",
    "content": "from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nimport torch\n\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image-Edit\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    processor_config=ModelConfig(model_id=\"Qwen/Qwen-Image-Edit\", origin_file_pattern=\"processor/\"),\n)\nprompt = \"精致肖像，水下少女，蓝裙飘逸，发丝轻扬，光影透澈，气泡环绕，面容恬静，细节精致，梦幻唯美。\"\ninput_image = pipe(prompt=prompt, seed=0, num_inference_steps=40, height=1328, width=1024)\ninput_image.save(\"image1.jpg\")\n\nprompt = \"将裙子改为粉色\"\n# edit_image_auto_resize=True: auto resize input image to match the area of 1024*1024 with the original aspect ratio\nimage = pipe(prompt, edit_image=input_image, seed=1, num_inference_steps=40, height=1328, width=1024, edit_image_auto_resize=True)\nimage.save(f\"image2.jpg\")\n\n# edit_image_auto_resize=False: do not resize input image\nimage = pipe(prompt, edit_image=input_image, seed=1, num_inference_steps=40, height=1328, width=1024, edit_image_auto_resize=False)\nimage.save(f\"image3.jpg\")\n"
  },
  {
    "path": "examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py",
    "content": "from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nimport torch\nfrom PIL import Image, ImageDraw, ImageFont\nfrom modelscope import dataset_snapshot_download, snapshot_download\nimport random\n\n\ndef visualize_masks(image, masks, mask_prompts, output_path, font_size=35, use_random_colors=False):\n    # Create a blank image for overlays\n    overlay = Image.new('RGBA', image.size, (0, 0, 0, 0))\n\n    colors = [\n        (165, 238, 173, 80),\n        (76, 102, 221, 80),\n        (221, 160, 77, 80),\n        (204, 93, 71, 80),\n        (145, 187, 149, 80),\n        (134, 141, 172, 80),\n        (157, 137, 109, 80),\n        (153, 104, 95, 80),\n        (165, 238, 173, 80),\n        (76, 102, 221, 80),\n        (221, 160, 77, 80),\n        (204, 93, 71, 80),\n        (145, 187, 149, 80),\n        (134, 141, 172, 80),\n        (157, 137, 109, 80),\n        (153, 104, 95, 80),\n    ]\n    # Generate random colors for each mask\n    if use_random_colors:\n        colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 80) for _ in range(len(masks))]\n\n    # Font settings\n    try:\n        font = ImageFont.truetype(\"wqy-zenhei.ttc\", font_size)  # Adjust as needed\n    except IOError:\n        font = ImageFont.load_default(font_size)\n\n    # Overlay each mask onto the overlay image\n    for mask, mask_prompt, color in zip(masks, mask_prompts, colors):\n        # Convert mask to RGBA mode\n        mask_rgba = mask.convert('RGBA')\n        mask_data = mask_rgba.getdata()\n        new_data = [(color if item[:3] == (255, 255, 255) else (0, 0, 0, 0)) for item in mask_data]\n        mask_rgba.putdata(new_data)\n\n        # Draw the mask prompt text on the mask\n        draw = ImageDraw.Draw(mask_rgba)\n        mask_bbox = mask.getbbox()  # Get the bounding box of the mask\n        text_position = (mask_bbox[0] + 10, mask_bbox[1] + 10)  # Adjust text position based on mask position\n        
draw.text(text_position, mask_prompt, fill=(255, 255, 255, 255), font=font)\n\n        # Alpha composite the overlay with this mask\n        overlay = Image.alpha_composite(overlay, mask_rgba)\n\n    # Composite the overlay onto the original image\n    result = Image.alpha_composite(image.convert('RGBA'), overlay)\n\n    # Save or display the resulting image\n    result.save(output_path)\n\n    return result\n\n\ndef example(pipe, seeds, example_id, global_prompt, entity_prompts, height=784, width=1280):\n    dataset_snapshot_download(\n        dataset_id=\"DiffSynth-Studio/examples_in_diffsynth\",\n        local_dir=\"./\",\n        allow_file_pattern=f\"data/examples/eligen/poster/example_{example_id}/*.png\"\n    )\n    masks = [\n        Image.open(f\"./data/examples/eligen/poster/example_{example_id}/{i}.png\").convert('RGB').resize((width, height))\n        for i in range(len(entity_prompts))\n    ]\n    negative_prompt = \"网格化，规则的网格，模糊, 低分辨率, 低质量, 变形, 畸形, 错误的解剖学, 变形的手, 变形的身体, 变形的脸, 变形的头发, 变形的眼睛, 变形的嘴巴\"\n    for seed in seeds:\n        # generate image\n        image = pipe(\n            prompt=global_prompt,\n            cfg_scale=4.0,\n            negative_prompt=negative_prompt,\n            num_inference_steps=40,\n            seed=seed,\n            height=height,\n            width=width,\n            eligen_entity_prompts=entity_prompts,\n            eligen_entity_masks=masks,\n        )\n        image.save(f\"eligen_poster_example_{example_id}_{seed}.png\")\n        image = Image.new(\"RGB\", (width, height), (0, 0, 0))\n        visualize_masks(image, masks, entity_prompts, f\"eligen_poster_example_{example_id}_mask_{seed}.png\")\n\n\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", 
origin_file_pattern=\"text_encoder/model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n)\nsnapshot_download(\n    \"DiffSynth-Studio/Qwen-Image-EliGen-Poster\",\n    local_dir=\"models/DiffSynth-Studio/Qwen-Image-EliGen-Poster\",\n    allow_file_pattern=\"model.safetensors\",\n)\npipe.load_lora(pipe.dit, \"models/DiffSynth-Studio/Qwen-Image-EliGen-Poster/model.safetensors\")\nglobal_prompt = \"一张以柔粉紫为背景的海报，左侧有大号粉紫色文字“Qwen-Image EliGen-Poster”，粉紫色椭圆框内白色小字：“图像精确分区控制模型”。右侧有一只小兔子在拆礼物，旁边站着一只头顶迷你烟花发射器的小龙（卡通Q版）。背景有一些白云点缀。整体风格卡通可爱，传达节日惊喜的主题。\"\nentity_prompts = [\"粉紫色文字“Qwen-Image EliGen-Poster”\", \"粉紫色椭圆框内白色小字：“图像精确分区控制模型”\", \"一只小兔子在拆礼物，小兔子旁边站着一只头顶迷你烟花发射器的小龙（卡通Q版）\"]\nseed = [42]\nexample(pipe, seed, 1, global_prompt, entity_prompts)\n"
  },
  {
    "path": "examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py",
    "content": "import torch\nimport random\nfrom PIL import Image, ImageDraw, ImageFont\nfrom modelscope import dataset_snapshot_download, snapshot_download\nfrom diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\n\ndef visualize_masks(image, masks, mask_prompts, output_path, font_size=35, use_random_colors=False):\n    # Create a blank image for overlays\n    overlay = Image.new('RGBA', image.size, (0, 0, 0, 0))\n\n    colors = [\n        (165, 238, 173, 80),\n        (76, 102, 221, 80),\n        (221, 160, 77, 80),\n        (204, 93, 71, 80),\n        (145, 187, 149, 80),\n        (134, 141, 172, 80),\n        (157, 137, 109, 80),\n        (153, 104, 95, 80),\n        (165, 238, 173, 80),\n        (76, 102, 221, 80),\n        (221, 160, 77, 80),\n        (204, 93, 71, 80),\n        (145, 187, 149, 80),\n        (134, 141, 172, 80),\n        (157, 137, 109, 80),\n        (153, 104, 95, 80),\n    ]\n    # Generate random colors for each mask\n    if use_random_colors:\n        colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 80) for _ in range(len(masks))]\n\n    # Font settings\n    try:\n        font = ImageFont.truetype(\"wqy-zenhei.ttc\", font_size)  # Adjust as needed\n    except IOError:\n        font = ImageFont.load_default(font_size)\n\n    # Overlay each mask onto the overlay image\n    for mask, mask_prompt, color in zip(masks, mask_prompts, colors):\n        # Convert mask to RGBA mode\n        mask_rgba = mask.convert('RGBA')\n        mask_data = mask_rgba.getdata()\n        new_data = [(color if item[:3] == (255, 255, 255) else (0, 0, 0, 0)) for item in mask_data]\n        mask_rgba.putdata(new_data)\n\n        # Draw the mask prompt text on the mask\n        draw = ImageDraw.Draw(mask_rgba)\n        mask_bbox = mask.getbbox()  # Get the bounding box of the mask\n        text_position = (mask_bbox[0] + 10, mask_bbox[1] + 10)  # Adjust text position based on mask position\n        
draw.text(text_position, mask_prompt, fill=(255, 255, 255, 255), font=font)\n\n        # Alpha composite the overlay with this mask\n        overlay = Image.alpha_composite(overlay, mask_rgba)\n\n    # Composite the overlay onto the original image\n    result = Image.alpha_composite(image.convert('RGBA'), overlay)\n\n    # Save or display the resulting image\n    result.save(output_path)\n\n    return result\n\ndef example(pipe, seeds, example_id, global_prompt, entity_prompts):\n    dataset_snapshot_download(dataset_id=\"DiffSynth-Studio/examples_in_diffsynth\", local_dir=\"./\", allow_file_pattern=f\"data/examples/eligen/qwen-image/example_{example_id}/*.png\")\n    masks = [Image.open(f\"./data/examples/eligen/qwen-image/example_{example_id}/{i}.png\").convert('RGB').resize((1024, 1024)) for i in range(len(entity_prompts))]\n    negative_prompt = \"网格化，规则的网格，模糊, 低分辨率, 低质量, 变形, 畸形, 错误的解剖学, 变形的手, 变形的身体, 变形的脸, 变形的头发, 变形的眼睛, 变形的嘴巴\"\n    for seed in seeds:\n        # generate image\n        image = pipe(\n            prompt=global_prompt,\n            cfg_scale=4.0,\n            negative_prompt=negative_prompt,\n            num_inference_steps=40,\n            seed=seed,\n            height=1024,\n            width=1024,\n            eligen_entity_prompts=entity_prompts,\n            eligen_entity_masks=masks,\n        )\n        image.save(f\"eligen_example_{example_id}_{seed}.png\")\n        visualize_masks(image, masks, entity_prompts, f\"eligen_example_{example_id}_mask_{seed}.png\")\n\n\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    
tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n)\nsnapshot_download(\"DiffSynth-Studio/Qwen-Image-EliGen-V2\", local_dir=\"models/DiffSynth-Studio/Qwen-Image-EliGen-V2\", allow_file_pattern=\"model.safetensors\")\npipe.load_lora(pipe.dit, \"models/DiffSynth-Studio/Qwen-Image-EliGen-V2/model.safetensors\")\n\nseeds = [0]\n\nglobal_prompt = \"写实摄影风格. A beautiful asia woman wearing white dress, she is holding a mirror with her right arm, with a beach background.\"\nentity_prompts = [\"A beautiful woman\", \"mirror\", \"necklace\", \"glasses\", \"earring\", \"white dress\", \"jewelry headpiece\"]\nexample(pipe, seeds, 7, global_prompt, entity_prompts)\n\nglobal_prompt = \"写实摄影风格, 细节丰富。街头一位漂亮的女孩，穿着衬衫和短裤，手持写有“实体控制”的标牌，背景是繁忙的城市街道，阳光明媚，行人匆匆。\"\nentity_prompts = [\"一个漂亮的女孩\", \"标牌 '实体控制'\", \"短裤\", \"衬衫\"]\nexample(pipe, seeds, 4, global_prompt, entity_prompts)\n"
  },
  {
    "path": "examples/qwen_image/model_inference/Qwen-Image-EliGen.py",
    "content": "from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nimport torch\nfrom PIL import Image, ImageDraw, ImageFont\nfrom modelscope import dataset_snapshot_download, snapshot_download\nimport random\n\n\ndef visualize_masks(image, masks, mask_prompts, output_path, font_size=35, use_random_colors=False):\n    # Create a blank image for overlays\n    overlay = Image.new('RGBA', image.size, (0, 0, 0, 0))\n\n    colors = [\n        (165, 238, 173, 80),\n        (76, 102, 221, 80),\n        (221, 160, 77, 80),\n        (204, 93, 71, 80),\n        (145, 187, 149, 80),\n        (134, 141, 172, 80),\n        (157, 137, 109, 80),\n        (153, 104, 95, 80),\n        (165, 238, 173, 80),\n        (76, 102, 221, 80),\n        (221, 160, 77, 80),\n        (204, 93, 71, 80),\n        (145, 187, 149, 80),\n        (134, 141, 172, 80),\n        (157, 137, 109, 80),\n        (153, 104, 95, 80),\n    ]\n    # Generate random colors for each mask\n    if use_random_colors:\n        colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 80) for _ in range(len(masks))]\n\n    # Font settings\n    try:\n        font = ImageFont.truetype(\"wqy-zenhei.ttc\", font_size)  # Adjust as needed\n    except IOError:\n        font = ImageFont.load_default(font_size)\n\n    # Overlay each mask onto the overlay image\n    for mask, mask_prompt, color in zip(masks, mask_prompts, colors):\n        # Convert mask to RGBA mode\n        mask_rgba = mask.convert('RGBA')\n        mask_data = mask_rgba.getdata()\n        new_data = [(color if item[:3] == (255, 255, 255) else (0, 0, 0, 0)) for item in mask_data]\n        mask_rgba.putdata(new_data)\n\n        # Draw the mask prompt text on the mask\n        draw = ImageDraw.Draw(mask_rgba)\n        mask_bbox = mask.getbbox()  # Get the bounding box of the mask\n        text_position = (mask_bbox[0] + 10, mask_bbox[1] + 10)  # Adjust text position based on mask position\n        
draw.text(text_position, mask_prompt, fill=(255, 255, 255, 255), font=font)\n\n        # Alpha composite the overlay with this mask\n        overlay = Image.alpha_composite(overlay, mask_rgba)\n\n    # Composite the overlay onto the original image\n    result = Image.alpha_composite(image.convert('RGBA'), overlay)\n\n    # Save or display the resulting image\n    result.save(output_path)\n\n    return result\n\ndef example(pipe, seeds, example_id, global_prompt, entity_prompts):\n    dataset_snapshot_download(dataset_id=\"DiffSynth-Studio/examples_in_diffsynth\", local_dir=\"./\", allow_file_pattern=f\"data/examples/eligen/qwen-image/example_{example_id}/*.png\")\n    masks = [Image.open(f\"./data/examples/eligen/qwen-image/example_{example_id}/{i}.png\").convert('RGB') for i in range(len(entity_prompts))]\n    negative_prompt = \"\"\n    for seed in seeds:\n        # generate image\n        image = pipe(\n            prompt=global_prompt,\n            cfg_scale=4.0,\n            negative_prompt=negative_prompt,\n            num_inference_steps=30,\n            seed=seed,\n            height=1024,\n            width=1024,\n            eligen_entity_prompts=entity_prompts,\n            eligen_entity_masks=masks,\n        )\n        image.save(f\"eligen_example_{example_id}_{seed}.png\")\n        visualize_masks(image, masks, entity_prompts, f\"eligen_example_{example_id}_mask_{seed}.png\")\n\n\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", 
origin_file_pattern=\"tokenizer/\"),\n)\nsnapshot_download(\"DiffSynth-Studio/Qwen-Image-EliGen\", local_dir=\"models/DiffSynth-Studio/Qwen-Image-EliGen\", allow_file_pattern=\"model.safetensors\")\npipe.load_lora(pipe.dit, \"models/DiffSynth-Studio/Qwen-Image-EliGen/model.safetensors\")\n\n# example 1\nglobal_prompt = \"A breathtaking beauty of Raja Ampat by the late-night moonlight , one beautiful woman from behind wearing a pale blue long dress with soft glow, sitting at the top of a cliff looking towards the beach,pastell light colors, a group of small distant birds flying in far sky, a boat sailing on the sea, best quality, realistic, whimsical, fantastic, splash art, intricate detailed, hyperdetailed, maximalist style, photorealistic, concept art, sharp focus, harmony, serenity, tranquility, soft pastell colors,ambient occlusion, cozy ambient lighting, masterpiece, liiv1, linquivera, metix, mentixis, masterpiece, award winning, view from above\\n\"\nentity_prompts = [\"cliff\", \"sea\", \"moon\", \"sailing boat\", \"a seated beautiful woman\", \"pale blue long dress with soft glow\"]\nexample(pipe, [0], 1, global_prompt, entity_prompts)\n\n# example 2\nglobal_prompt = \"samurai girl wearing a kimono, she's holding a sword  glowing with red flame, her long hair is flowing in the wind, she is looking at a small bird perched on the back of her hand. ultra realist style. maximum image detail. maximum realistic render.\"\nentity_prompts = [\"flowing hair\", \"sword glowing with red flame\", \"A cute bird\", \"yellow belt\"]\nexample(pipe, [0], 2, global_prompt, entity_prompts)\n"
  },
  {
    "path": "examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py",
    "content": "from PIL import Image\nimport torch\nfrom modelscope import dataset_snapshot_download, snapshot_download\nfrom diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nfrom diffsynth.utils.controlnet import Annotator\n\nallow_file_pattern = [\"sk_model.pth\", \"sk_model2.pth\", \"dpt_hybrid-midas-501f0c75.pt\", \"ControlNetHED.pth\", \"body_pose_model.pth\", \"hand_pose_model.pth\", \"facenet.pth\", \"scannet.pt\"]\nsnapshot_download(\"lllyasviel/Annotators\", local_dir=\"models/Annotators\", allow_file_pattern=allow_file_pattern)\n\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n)\nsnapshot_download(\"DiffSynth-Studio/Qwen-Image-In-Context-Control-Union\", local_dir=\"models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union\", allow_file_pattern=\"model.safetensors\")\npipe.load_lora(pipe.dit, \"models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union/model.safetensors\")\n\ndataset_snapshot_download(dataset_id=\"DiffSynth-Studio/examples_in_diffsynth\", local_dir=\"./\", allow_file_pattern=f\"data/examples/qwen-image-context-control/image.jpg\")\norigin_image = Image.open(\"data/examples/qwen-image-context-control/image.jpg\").resize((1024, 1024))\nannotator_ids = ['openpose', 'canny', 'depth', 'lineart', 'softedge', 'normal']\nfor annotator_id in annotator_ids:\n    annotator = Annotator(processor_id=annotator_id, device=\"cuda\")\n    control_image = annotator(origin_image)\n    
control_image.save(f\"{annotator.processor_id}.png\")\n\n    control_prompt = \"Context_Control. \"\n    prompt = f\"{control_prompt}一个穿着淡蓝色的漂亮女孩正在翩翩起舞，背景是梦幻的星空，光影交错，细节精致。\"\n    negative_prompt = \"网格化，规则的网格，模糊, 低分辨率, 低质量, 变形, 畸形, 错误的解剖学, 变形的手, 变形的身体, 变形的脸, 变形的头发, 变形的眼睛, 变形的嘴巴\"\n    image = pipe(prompt, seed=1, negative_prompt=negative_prompt, context_image=control_image, height=1024, width=1024)\n    image.save(f\"image_{annotator.processor_id}.png\")\n"
  },
  {
    "path": "examples/qwen_image/model_inference/Qwen-Image-Layered-Control-V2.py",
    "content": "from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\nfrom PIL import Image\nimport torch\n\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"DiffSynth-Studio/Qwen-Image-Layered-Control\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image-Layered\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n)\npipe.load_lora(pipe.dit, ModelConfig(model_id=\"DiffSynth-Studio/Qwen-Image-Layered-Control-V2\", origin_file_pattern=\"model.safetensors\"))\n\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/example_image_dataset\",\n    local_dir=\"./data/example_image_dataset\",\n    allow_file_pattern=\"layer_v2/*.png\"\n)\n\nprompt = \"Text 'APRIL'\"\ninput_image = Image.open(\"data/example_image_dataset/layer_v2/image_1.png\").convert(\"RGBA\").resize((1024, 1024))\nimage = pipe(\n    prompt, seed=0,\n    height=1024, width=1024,\n    layer_input_image=input_image, layer_num=0,\n    num_inference_steps=10, cfg_scale=4,\n)\nimage[0].save(\"image_prompt.png\")\n\nmask_image = Image.open(\"data/example_image_dataset/layer_v2/mask_2.png\").convert(\"RGBA\").resize((1024, 1024))\ninput_image = Image.open(\"data/example_image_dataset/layer_v2/image_2.png\").convert(\"RGBA\").resize((1024, 1024))\nimage = pipe(\n    prompt, seed=0,\n    height=1024, width=1024,\n    layer_input_image=input_image, layer_num=0,\n    context_image=mask_image,\n    num_inference_steps=10, cfg_scale=1.0,\n)\nimage[0].save(\"image_mask.png\")\n"
  },
  {
    "path": "examples/qwen_image/model_inference/Qwen-Image-Layered-Control.py",
    "content": "from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nfrom modelscope import snapshot_download\nfrom PIL import Image\nimport torch\n\n\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"DiffSynth-Studio/Qwen-Image-Layered-Control\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image-Layered\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    processor_config=ModelConfig(model_id=\"Qwen/Qwen-Image-Edit\", origin_file_pattern=\"processor/\"),\n)\n\nsnapshot_download(\n    model_id=\"DiffSynth-Studio/Qwen-Image-Layered-Control\",\n    allow_file_pattern=\"assets/image_1_input.png\",\n    local_dir=\"data/layered_input\"\n)\n\nprompt = \"A cartoon skeleton character wearing a purple hat and holding a gift box\"\ninput_image = Image.open(\"data/layered_input/assets/image_1_input.png\").convert(\"RGBA\").resize((1024, 1024))\nimages = pipe(\n    prompt,\n    seed=0,\n    num_inference_steps=30, cfg_scale=4,\n    height=1024, width=1024,\n    layer_input_image=input_image,\n    layer_num=0,\n)\nimages[0].save(\"image.png\")\n"
  },
  {
    "path": "examples/qwen_image/model_inference/Qwen-Image-Layered.py",
    "content": "from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\nfrom PIL import Image\nimport torch\n\n\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image-Layered\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image-Layered\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    processor_config=ModelConfig(model_id=\"Qwen/Qwen-Image-Edit\", origin_file_pattern=\"processor/\"),\n)\n\ndataset_snapshot_download(\n    \"DiffSynth-Studio/example_image_dataset\",\n    allow_patterns=\"layer/image.png\",\n    local_dir=\"data/example_image_dataset\"\n)\n\n# Prompt should be provided to the pipeline. Our pipeline will not generate the prompt.\nprompt = 'A cheerful child with brown hair is waving enthusiastically under a bright blue sky filled with colorful confetti and balloons. The word \"HELLO!\" is prominently displayed in bold red letters above the child, while \"Have a Great Day!\" appears in elegant cursive at the bottom right corner. The scene is vibrant and festive, with a mix of pastel colors and dynamic shapes creating a joyful atmosphere.'\n# Height and width should be consistent with input_image and be divided evenly by 16\ninput_image = Image.open(\"data/example_image_dataset/layer/image.png\").convert(\"RGBA\").resize((864, 480))\nimages = pipe(\n    prompt,\n    seed=1, num_inference_steps=50,\n    height=480, width=864,\n    layer_input_image=input_image, layer_num=3,\n)\nfor i, image in enumerate(images):\n    if i == 0: continue # The first image is the input image.\n    image.save(f\"image_{i}.png\")\n"
  },
  {
    "path": "examples/qwen_image/model_inference/Qwen-Image-i2L.py",
    "content": "from diffsynth.pipelines.qwen_image import (\n    QwenImagePipeline, ModelConfig,\n    QwenImageUnit_Image2LoRAEncode, QwenImageUnit_Image2LoRADecode\n)\nfrom diffsynth.utils.lora import merge_lora\nfrom diffsynth import load_state_dict\nfrom modelscope import snapshot_download\nfrom safetensors.torch import save_file\nimport torch\nfrom PIL import Image\n\n\ndef demo_style():\n    # Load models\n    pipe = QwenImagePipeline.from_pretrained(\n        torch_dtype=torch.bfloat16,\n        device=\"cuda\",\n        model_configs=[\n            ModelConfig(model_id=\"DiffSynth-Studio/General-Image-Encoders\", origin_file_pattern=\"SigLIP2-G384/model.safetensors\"),\n            ModelConfig(model_id=\"DiffSynth-Studio/General-Image-Encoders\", origin_file_pattern=\"DINOv3-7B/model.safetensors\"),\n            ModelConfig(model_id=\"DiffSynth-Studio/Qwen-Image-i2L\", origin_file_pattern=\"Qwen-Image-i2L-Style.safetensors\"),\n        ],\n        processor_config=ModelConfig(model_id=\"Qwen/Qwen-Image-Edit\", origin_file_pattern=\"processor/\"),\n    )\n\n    # Load images\n    snapshot_download(\n        model_id=\"DiffSynth-Studio/Qwen-Image-i2L\",\n        allow_file_pattern=\"assets/style/1/*\",\n        local_dir=\"data/examples\"\n    )\n    images = [\n        Image.open(\"data/examples/assets/style/1/0.jpg\"),\n        Image.open(\"data/examples/assets/style/1/1.jpg\"),\n        Image.open(\"data/examples/assets/style/1/2.jpg\"),\n        Image.open(\"data/examples/assets/style/1/3.jpg\"),\n        Image.open(\"data/examples/assets/style/1/4.jpg\"),\n    ]\n\n    # Model inference\n    with torch.no_grad():\n        embs = QwenImageUnit_Image2LoRAEncode().process(pipe, image2lora_images=images)\n        lora = QwenImageUnit_Image2LoRADecode().process(pipe, **embs)[\"lora\"]\n    save_file(lora, \"model_style.safetensors\")\n\n\ndef demo_coarse_fine_bias():\n    # Load models\n    pipe = QwenImagePipeline.from_pretrained(\n        
torch_dtype=torch.bfloat16,\n        device=\"cuda\",\n        model_configs=[\n            ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n            ModelConfig(model_id=\"DiffSynth-Studio/General-Image-Encoders\", origin_file_pattern=\"SigLIP2-G384/model.safetensors\"),\n            ModelConfig(model_id=\"DiffSynth-Studio/General-Image-Encoders\", origin_file_pattern=\"DINOv3-7B/model.safetensors\"),\n            ModelConfig(model_id=\"DiffSynth-Studio/Qwen-Image-i2L\", origin_file_pattern=\"Qwen-Image-i2L-Coarse.safetensors\"),\n            ModelConfig(model_id=\"DiffSynth-Studio/Qwen-Image-i2L\", origin_file_pattern=\"Qwen-Image-i2L-Fine.safetensors\"),\n        ],\n        processor_config=ModelConfig(model_id=\"Qwen/Qwen-Image-Edit\", origin_file_pattern=\"processor/\"),\n    )\n\n    # Load images\n    snapshot_download(\n        model_id=\"DiffSynth-Studio/Qwen-Image-i2L\",\n        allow_file_pattern=\"assets/lora/3/*\",\n        local_dir=\"data/examples\"\n    )\n    images = [\n        Image.open(\"data/examples/assets/lora/3/0.jpg\"),\n        Image.open(\"data/examples/assets/lora/3/1.jpg\"),\n        Image.open(\"data/examples/assets/lora/3/2.jpg\"),\n        Image.open(\"data/examples/assets/lora/3/3.jpg\"),\n        Image.open(\"data/examples/assets/lora/3/4.jpg\"),\n        Image.open(\"data/examples/assets/lora/3/5.jpg\"),\n    ]\n\n    # Model inference\n    with torch.no_grad():\n        embs = QwenImageUnit_Image2LoRAEncode().process(pipe, image2lora_images=images)\n        lora = QwenImageUnit_Image2LoRADecode().process(pipe, **embs)[\"lora\"]\n        lora_bias = ModelConfig(model_id=\"DiffSynth-Studio/Qwen-Image-i2L\", origin_file_pattern=\"Qwen-Image-i2L-Bias.safetensors\")\n        lora_bias.download_if_necessary()\n        lora_bias = load_state_dict(lora_bias.path, torch_dtype=torch.bfloat16, device=\"cuda\")\n        lora = merge_lora([lora, lora_bias])\n    save_file(lora, 
\"model_coarse_fine_bias.safetensors\")\n\n\ndef generate_image(lora_path, prompt, seed):\n    pipe = QwenImagePipeline.from_pretrained(\n        torch_dtype=torch.bfloat16,\n        device=\"cuda\",\n        model_configs=[\n            ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n            ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n            ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n        ],\n        tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n    )\n    pipe.load_lora(pipe.dit, lora_path)\n    image = pipe(prompt, seed=seed, height=1024, width=1024, num_inference_steps=50)\n    return image\n\n\ndemo_style()\nimage = generate_image(\"model_style.safetensors\", \"a cat\", 0)\nimage.save(\"image_1.jpg\")\n\ndemo_coarse_fine_bias()\nimage = generate_image(\"model_coarse_fine_bias.safetensors\", \"bowl\", 1)\nimage.save(\"image_2.jpg\")\n"
  },
  {
    "path": "examples/qwen_image/model_inference/Qwen-Image.py",
    "content": "from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nimport torch\n\n\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n)\nprompt = \"精致肖像，水下少女，蓝裙飘逸，发丝轻扬，光影透澈，气泡环绕，面容恬静，细节精致，梦幻唯美。\"\nimage = pipe(prompt, seed=0, num_inference_steps=40)\nimage.save(\"image.jpg\")\n"
  },
  {
    "path": "examples/qwen_image/model_inference_low_vram/FireRed-Image-Edit-1.0.py",
    "content": "from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\nfrom PIL import Image\nimport torch\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.float8_e4m3fn,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e4m3fn,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"FireRedTeam/FireRed-Image-Edit-1.0\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\", **vram_config),\n    ],\n    processor_config=ModelConfig(model_id=\"Qwen/Qwen-Image-Edit\", origin_file_pattern=\"processor/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\n\ndataset_snapshot_download(\n    \"DiffSynth-Studio/example_image_dataset\",\n    allow_file_pattern=\"qwen_image_edit/*\",\n    local_dir=\"data/example_image_dataset\",\n)\n\nprompt = \"生成这两个人的合影\"\nedit_image = [\n    Image.open(\"data/example_image_dataset/qwen_image_edit/image1.jpg\"),\n    Image.open(\"data/example_image_dataset/qwen_image_edit/image2.jpg\"),\n]\nimage = pipe(\n    prompt,\n    edit_image=edit_image,\n    seed=1,\n    num_inference_steps=40,\n    height=1152,\n    width=896,\n    edit_image_auto_resize=True,\n)\nimage.save(\"image.jpg\")\n\n# FireRedTeam/FireRed-Image-Edit-1.0 is a multi-image editing model.\n# Please use a list to input `edit_image`, even if the input contains only one image.\n# edit_image = 
[Image.open(\"image.jpg\")]\n# Please do not input the image directly.\n# edit_image = Image.open(\"image.jpg\")\n"
  },
  {
    "path": "examples/qwen_image/model_inference_low_vram/FireRed-Image-Edit-1.1.py",
    "content": "from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\nfrom PIL import Image\nimport torch\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.float8_e4m3fn,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e4m3fn,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"FireRedTeam/FireRed-Image-Edit-1.1\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\", **vram_config),\n    ],\n    processor_config=ModelConfig(model_id=\"Qwen/Qwen-Image-Edit\", origin_file_pattern=\"processor/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\n\ndataset_snapshot_download(\n    \"DiffSynth-Studio/example_image_dataset\",\n    allow_file_pattern=\"qwen_image_edit/*\",\n    local_dir=\"data/example_image_dataset\",\n)\n\nprompt = \"生成这两个人的合影\"\nedit_image = [\n    Image.open(\"data/example_image_dataset/qwen_image_edit/image1.jpg\"),\n    Image.open(\"data/example_image_dataset/qwen_image_edit/image2.jpg\"),\n]\nimage = pipe(\n    prompt,\n    edit_image=edit_image,\n    seed=1,\n    num_inference_steps=40,\n    height=1152,\n    width=896,\n    edit_image_auto_resize=True,\n)\nimage.save(\"image.jpg\")\n\n# FireRedTeam/FireRed-Image-Edit-1.1 is a multi-image editing model.\n# Please use a list to input `edit_image`, even if the input contains only one image.\n# edit_image = 
[Image.open(\"image.jpg\")]\n# Please do not input the image directly.\n# edit_image = Image.open(\"image.jpg\")\n"
  },
  {
    "path": "examples/qwen_image/model_inference_low_vram/Qwen-Image-2512.py",
    "content": "from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nimport torch\n\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.float8_e4m3fn,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e4m3fn,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image-2512\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\nprompt = \"精致肖像，水下少女，蓝裙飘逸，发丝轻扬，光影透澈，气泡环绕，面容恬静，细节精致，梦幻唯美。\"\nimage = pipe(prompt, seed=0, num_inference_steps=40)\nimage.save(\"image.jpg\")\n"
  },
  {
    "path": "examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Canny.py",
    "content": "from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput\nfrom PIL import Image\nimport torch\nfrom modelscope import dataset_snapshot_download\n\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.float8_e4m3fn,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e4m3fn,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny\", origin_file_pattern=\"model.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\n\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/example_image_dataset\",\n    local_dir=\"./data/example_image_dataset\",\n    allow_file_pattern=\"canny/image_1.jpg\"\n)\ncontrolnet_image = Image.open(\"data/example_image_dataset/canny/image_1.jpg\").resize((1328, 1328))\n\nprompt = \"一只小狗，毛发光洁柔顺，眼神灵动，背景是樱花纷飞的春日庭院，唯美温馨。\"\nimage = pipe(\n    prompt, seed=0,\n    blockwise_controlnet_inputs=[ControlNetInput(image=controlnet_image)]\n)\nimage.save(\"image.jpg\")\n"
  },
  {
    "path": "examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Depth.py",
    "content": "from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput\nfrom PIL import Image\nimport torch\nfrom modelscope import dataset_snapshot_download\n\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.float8_e4m3fn,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e4m3fn,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth\", origin_file_pattern=\"model.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\n\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/example_image_dataset\",\n    local_dir=\"./data/example_image_dataset\",\n    allow_file_pattern=\"depth/image_1.jpg\"\n)\n\ncontrolnet_image = Image.open(\"data/example_image_dataset/depth/image_1.jpg\").resize((1328, 1328))\n\nprompt = \"精致肖像，水下少女，蓝裙飘逸，发丝轻扬，光影透澈，气泡环绕，面容恬静，细节精致，梦幻唯美。\"\nimage = pipe(\n    prompt, seed=0,\n    blockwise_controlnet_inputs=[ControlNetInput(image=controlnet_image)]\n)\nimage.save(\"image.jpg\")\n"
  },
  {
    "path": "examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py",
    "content": "import torch\nfrom PIL import Image\nfrom modelscope import dataset_snapshot_download\nfrom diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput\n\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.float8_e4m3fn,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e4m3fn,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint\", origin_file_pattern=\"model.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\n\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/example_image_dataset\",\n    local_dir=\"./data/example_image_dataset\",\n    allow_file_pattern=\"inpaint/*.jpg\"\n)\nprompt = \"a cat with sunglasses\"\ncontrolnet_image = Image.open(\"./data/example_image_dataset/inpaint/image_1.jpg\").convert(\"RGB\").resize((1328, 1328))\ninpaint_mask = Image.open(\"./data/example_image_dataset/inpaint/mask.jpg\").convert(\"RGB\").resize((1328, 1328))\nimage = pipe(\n    prompt, seed=0,\n    input_image=controlnet_image, inpaint_mask=inpaint_mask,\n    
blockwise_controlnet_inputs=[ControlNetInput(image=controlnet_image, inpaint_mask=inpaint_mask)],\n    num_inference_steps=40,\n)\nimage.save(\"image.jpg\")\n"
  },
  {
    "path": "examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-InpaintCanny.py",
    "content": "import torch\nfrom PIL import Image\nfrom modelscope import dataset_snapshot_download\nfrom diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.float8_e4m3fn,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e4m3fn,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint\", origin_file_pattern=\"model.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny\", origin_file_pattern=\"model.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\n\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/example_image_dataset\",\n    local_dir=\"./data/example_image_dataset\",\n    allow_file_pattern=\"canny/*.jpg\"\n)\nprompt = \"一只小狗，毛发光洁柔顺，眼神灵动，背景是樱花纷飞的春日庭院，唯美温馨。\"\n\ncontrolnet_canny_image = Image.open(\"data/example_image_dataset/canny/image_1.jpg\").resize((1328, 1328))\n\ncontrolnet_inpaint_image = Image.open(\"./data/example_image_dataset/canny/image_2.jpg\").convert(\"RGB\").resize((1328, 
1328))\n# generate a centered square mask\ninpaint_mask = Image.new(\"L\", controlnet_inpaint_image.size, 0)\nmask_size = 512\nleft = (controlnet_inpaint_image.width - mask_size) // 2\ntop = (controlnet_inpaint_image.height - mask_size) // 2\nright = left + mask_size\nbottom = top + mask_size\ninpaint_mask.paste(255, (left, top, right, bottom))\ninpaint_mask = inpaint_mask.resize((1328, 1328)).convert(\"RGB\")\n\nimage = pipe(\n    prompt, seed=0,\n    input_image=controlnet_inpaint_image, inpaint_mask=inpaint_mask,\n    blockwise_controlnet_inputs=[\n        ControlNetInput(image=controlnet_inpaint_image, inpaint_mask=inpaint_mask, controlnet_id=0),\n        ControlNetInput(image=controlnet_canny_image, controlnet_id=1),\n    ],\n    num_inference_steps=40,\n)\nimage.save(\"image.jpg\")\n"
  },
  {
    "path": "examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-DMD2.py",
    "content": "from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nfrom diffsynth.core import load_state_dict\nfrom modelscope import snapshot_download\nimport torch, math\n\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.float8_e4m3fn, # bfloat16 is recommended.\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e4m3fn, # bfloat16 is recommended.\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\n\nsnapshot_download(\"MusePublic/Qwen-Image-Distill\", allow_file_pattern=\"qwen_image_distill_3step.safetensors\", cache_dir=\"models\")\nlora_state_dict = load_state_dict(\"models/MusePublic/Qwen-Image-Distill/qwen_image_distill_3step.safetensors\", device=\"cuda\", torch_dtype=torch.bfloat16)\nlora_state_dict = {i.replace(\"base_model.model.\", \"\").replace(\".weight\", \".default.weight\"): j for i, j in lora_state_dict.items()}\npipe.load_lora(pipe.dit, state_dict=lora_state_dict, hotload=True)\n\nprompt = \"精致肖像，水下少女，蓝裙飘逸，发丝轻扬，光影透澈，气泡环绕，面容恬静，细节精致，梦幻唯美。\"\nimage = pipe(prompt, seed=0, num_inference_steps=3, cfg_scale=1, exponential_shift_mu=math.log(2.5))\nimage.save(\"image.jpg\")\n"
  },
  {
    "path": "examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-Full.py",
    "content": "from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nimport torch\n\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.float8_e4m3fn,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e4m3fn,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"DiffSynth-Studio/Qwen-Image-Distill-Full\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\nprompt = \"精致肖像，水下少女，蓝裙飘逸，发丝轻扬，光影透澈，气泡环绕，面容恬静，细节精致，梦幻唯美。\"\nimage = pipe(prompt, seed=0, num_inference_steps=15, cfg_scale=1)\nimage.save(\"image.jpg\")\n"
  },
  {
    "path": "examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-LoRA.py",
    "content": "from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nfrom modelscope import snapshot_download\nimport torch\n\nsnapshot_download(\"DiffSynth-Studio/Qwen-Image-Distill-LoRA\", local_dir=\"models/DiffSynth-Studio/Qwen-Image-Distill-LoRA\")\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.float8_e4m3fn,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e4m3fn,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\npipe.load_lora(pipe.dit, \"models/DiffSynth-Studio/Qwen-Image-Distill-LoRA/model.safetensors\", hotload=True)\n\nprompt = \"精致肖像，水下少女，蓝裙飘逸，发丝轻扬，光影透澈，气泡环绕，面容恬静，细节精致，梦幻唯美。\"\nimage = pipe(prompt, seed=0, num_inference_steps=15, cfg_scale=1)\nimage.save(\"image.jpg\")"
  },
  {
    "path": "examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py",
    "content": "from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nfrom PIL import Image\nimport torch\n\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.float8_e4m3fn,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e4m3fn,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image-Edit-2509\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\", **vram_config),\n    ],\n    processor_config=ModelConfig(model_id=\"Qwen/Qwen-Image-Edit\", origin_file_pattern=\"processor/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\n\nimage_1 = pipe(prompt=\"一位少女\", seed=0, num_inference_steps=40, height=1328, width=1024)\nimage_1.save(\"image1.jpg\")\n\nimage_2 = pipe(prompt=\"一位老人\", seed=0, num_inference_steps=40, height=1328, width=1024)\nimage_2.save(\"image2.jpg\")\n\nprompt = \"生成这两个人的合影\"\nedit_image = [Image.open(\"image1.jpg\"), Image.open(\"image2.jpg\")]\nimage_3 = pipe(prompt, edit_image=edit_image, seed=1, num_inference_steps=40, height=1328, width=1024, edit_image_auto_resize=True)\nimage_3.save(\"image3.jpg\")\n\n# Qwen-Image-Edit-2509 is a multi-image editing model.\n# Please use a list to input `edit_image`, even if the input contains only one image.\n# edit_image = [Image.open(\"image.jpg\")]\n# Please do not input the image directly.\n# edit_image = Image.open(\"image.jpg\")\n"
  },
  {
    "path": "examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511-ICEdit.py",
    "content": "from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nfrom modelscope import snapshot_download\nfrom PIL import Image\nimport torch\n\n# Load models\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.float8_e4m3fn,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e4m3fn,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image-Edit-2511\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\", **vram_config),\n    ],\n    processor_config=ModelConfig(model_id=\"Qwen/Qwen-Image-Edit\", origin_file_pattern=\"processor/\"),\n)\n\nlora = ModelConfig(\n    model_id=\"DiffSynth-Studio/Qwen-Image-Edit-2511-ICEdit-LoRA\",\n    origin_file_pattern=\"model.safetensors\"\n)\npipe.load_lora(pipe.dit, lora)\n\n# Load images\nsnapshot_download(\n    \"DiffSynth-Studio/Qwen-Image-Edit-2511-ICEdit-LoRA\",\n    local_dir=\"./data\",\n    allow_file_pattern=\"assets/*\"\n)\nedit_image = [\n    Image.open(\"data/assets/image1_original.png\"),\n    Image.open(\"data/assets/image1_edit_1.png\"),\n    Image.open(\"data/assets/image2_original.png\")\n]\nprompt = \"Edit image 3 based on the transformation from image 1 to image 2.\"\nnegative_prompt = \"泛黄，AI感，不真实，丑陋，油腻的皮肤，异常的肢体，不协调的肢体\"\n\n# Generate\nimage_4 = pipe(\n    prompt=prompt, negative_prompt=negative_prompt,\n    edit_image=edit_image,\n    seed=1,\n    num_inference_steps=50,\n    height=1280,\n    
width=720,\n    zero_cond_t=True,\n)\nimage_4.save(\"image.png\")"
  },
  {
    "path": "examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511-Lightning.py",
    "content": "from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, FlowMatchScheduler\nfrom modelscope import dataset_snapshot_download\nfrom PIL import Image\nimport torch\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.float8_e4m3fn,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e4m3fn,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image-Edit-2511\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\", **vram_config),\n    ],\n    processor_config=ModelConfig(model_id=\"Qwen/Qwen-Image-Edit\", origin_file_pattern=\"processor/\"),\n)\n\nlora = ModelConfig(\n    model_id=\"lightx2v/Qwen-Image-Edit-2511-Lightning\",\n    origin_file_pattern=\"Qwen-Image-Edit-2511-Lightning-4steps-V1.0-bf16.safetensors\"\n)\npipe.load_lora(pipe.dit, lora, alpha=1)\npipe.scheduler = FlowMatchScheduler(\"Qwen-Image-Lightning\")\n\n\ndataset_snapshot_download(\n    \"DiffSynth-Studio/example_image_dataset\",\n    allow_file_pattern=\"qwen_image_edit/*\",\n    local_dir=\"data/example_image_dataset\",\n)\n\nprompt = \"生成这两个人的合影\"\nedit_image = [\n    Image.open(\"data/example_image_dataset/qwen_image_edit/image1.jpg\"),\n    Image.open(\"data/example_image_dataset/qwen_image_edit/image2.jpg\"),\n]\nimage = pipe(\n    prompt,\n    edit_image=edit_image,\n    seed=1,\n    num_inference_steps=4,\n    height=1152,\n    width=896,\n    edit_image_auto_resize=True,\n    
zero_cond_t=True, # This is a special parameter introduced by Qwen-Image-Edit-2511\n    cfg_scale=1.0,\n)\nimage.save(\"image.jpg\")\n\n# Qwen-Image-Edit-2511 is a multi-image editing model.\n# Please use a list to input `edit_image`, even if the input contains only one image.\n# edit_image = [Image.open(\"image.jpg\")]\n# Please do not input the image directly.\n# edit_image = Image.open(\"image.jpg\")\n"
  },
  {
    "path": "examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py",
    "content": "from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\nfrom PIL import Image\nimport torch\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.float8_e4m3fn,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e4m3fn,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image-Edit-2511\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\", **vram_config),\n    ],\n    processor_config=ModelConfig(model_id=\"Qwen/Qwen-Image-Edit\", origin_file_pattern=\"processor/\"),\n)\n\ndataset_snapshot_download(\n    \"DiffSynth-Studio/example_image_dataset\",\n    allow_file_pattern=\"qwen_image_edit/*\",\n    local_dir=\"data/example_image_dataset\",\n)\n\nprompt = \"生成这两个人的合影\"\nedit_image = [\n    Image.open(\"data/example_image_dataset/qwen_image_edit/image1.jpg\"),\n    Image.open(\"data/example_image_dataset/qwen_image_edit/image2.jpg\"),\n]\nimage = pipe(\n    prompt,\n    edit_image=edit_image,\n    seed=1,\n    num_inference_steps=40,\n    height=1152,\n    width=896,\n    edit_image_auto_resize=True,\n    zero_cond_t=True, # This is a special parameter introduced by Qwen-Image-Edit-2511\n)\nimage.save(\"image.jpg\")\n\n# Qwen-Image-Edit-2511 is a multi-image editing model.\n# Please use a list to input `edit_image`, even if the input contains only one image.\n# edit_image = 
[Image.open(\"image.jpg\")]\n# Please do not input the image directly.\n# edit_image = Image.open(\"image.jpg\")\n"
  },
  {
    "path": "examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py",
    "content": "from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nimport torch\nfrom modelscope import snapshot_download\n\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.float8_e4m3fn,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e4m3fn,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image-Edit\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\", **vram_config),\n    ],\n    processor_config=ModelConfig(model_id=\"Qwen/Qwen-Image-Edit\", origin_file_pattern=\"processor/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\nsnapshot_download(\"DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix\", local_dir=\"models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix\", allow_file_pattern=\"model.safetensors\")\npipe.load_lora(pipe.dit, \"models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix/model.safetensors\", hotload=True)\n\nprompt = \"精致肖像，水下少女，蓝裙飘逸，发丝轻扬，光影透澈，气泡环绕，面容恬静，细节精致，梦幻唯美。\"\nimage = pipe(prompt=prompt, seed=0, num_inference_steps=40, height=1024, width=768)\nimage.save(\"image.jpg\")\n\nprompt = \"将裙子变成粉色\"\nimage = image.resize((512, 384))\nimage = pipe(prompt, edit_image=image, seed=1, num_inference_steps=40, height=1024, width=768, edit_rope_interpolation=True, edit_image_auto_resize=False)\nimage.save(f\"image2.jpg\")\n"
  },
  {
    "path": "examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py",
    "content": "from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nimport torch\n\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.float8_e4m3fn,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e4m3fn,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image-Edit\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\", **vram_config),\n    ],\n    processor_config=ModelConfig(model_id=\"Qwen/Qwen-Image-Edit\", origin_file_pattern=\"processor/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\nprompt = \"精致肖像，水下少女，蓝裙飘逸，发丝轻扬，光影透澈，气泡环绕，面容恬静，细节精致，梦幻唯美。\"\ninput_image = pipe(prompt=prompt, seed=0, num_inference_steps=40, height=1328, width=1024)\ninput_image.save(\"image1.jpg\")\n\nprompt = \"将裙子改为粉色\"\n# edit_image_auto_resize=True: auto resize input image to match the area of 1024*1024 with the original aspect ratio\nimage = pipe(prompt, edit_image=input_image, seed=1, num_inference_steps=40, height=1328, width=1024, edit_image_auto_resize=True)\nimage.save(f\"image2.jpg\")\n\n# edit_image_auto_resize=False: do not resize input image\nimage = pipe(prompt, edit_image=input_image, seed=1, num_inference_steps=40, height=1328, width=1024, edit_image_auto_resize=False)\nimage.save(f\"image3.jpg\")\n"
  },
  {
    "path": "examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py",
    "content": "from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nimport torch\nfrom PIL import Image, ImageDraw, ImageFont\nfrom modelscope import dataset_snapshot_download, snapshot_download\nimport random\n\n\ndef visualize_masks(image, masks, mask_prompts, output_path, font_size=35, use_random_colors=False):\n    # Create a blank image for overlays\n    overlay = Image.new('RGBA', image.size, (0, 0, 0, 0))\n\n    colors = [\n        (165, 238, 173, 80),\n        (76, 102, 221, 80),\n        (221, 160, 77, 80),\n        (204, 93, 71, 80),\n        (145, 187, 149, 80),\n        (134, 141, 172, 80),\n        (157, 137, 109, 80),\n        (153, 104, 95, 80),\n        (165, 238, 173, 80),\n        (76, 102, 221, 80),\n        (221, 160, 77, 80),\n        (204, 93, 71, 80),\n        (145, 187, 149, 80),\n        (134, 141, 172, 80),\n        (157, 137, 109, 80),\n        (153, 104, 95, 80),\n    ]\n    # Generate random colors for each mask\n    if use_random_colors:\n        colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 80) for _ in range(len(masks))]\n\n    # Font settings\n    try:\n        font = ImageFont.truetype(\"wqy-zenhei.ttc\", font_size)  # Adjust as needed\n    except IOError:\n        font = ImageFont.load_default(font_size)\n\n    # Overlay each mask onto the overlay image\n    for mask, mask_prompt, color in zip(masks, mask_prompts, colors):\n        # Convert mask to RGBA mode\n        mask_rgba = mask.convert('RGBA')\n        mask_data = mask_rgba.getdata()\n        new_data = [(color if item[:3] == (255, 255, 255) else (0, 0, 0, 0)) for item in mask_data]\n        mask_rgba.putdata(new_data)\n\n        # Draw the mask prompt text on the mask\n        draw = ImageDraw.Draw(mask_rgba)\n        mask_bbox = mask.getbbox()  # Get the bounding box of the mask\n        text_position = (mask_bbox[0] + 10, mask_bbox[1] + 10)  # Adjust text position based on mask position\n        
draw.text(text_position, mask_prompt, fill=(255, 255, 255, 255), font=font)\n\n        # Alpha composite the overlay with this mask\n        overlay = Image.alpha_composite(overlay, mask_rgba)\n\n    # Composite the overlay onto the original image\n    result = Image.alpha_composite(image.convert('RGBA'), overlay)\n\n    # Save or display the resulting image\n    result.save(output_path)\n\n    return result\n\n\ndef example(pipe, seeds, example_id, global_prompt, entity_prompts, height=784, width=1280):\n    dataset_snapshot_download(\n        dataset_id=\"DiffSynth-Studio/examples_in_diffsynth\",\n        local_dir=\"./\",\n        allow_file_pattern=f\"data/examples/eligen/poster/example_{example_id}/*.png\"\n    )\n    masks = [\n        Image.open(f\"./data/examples/eligen/poster/example_{example_id}/{i}.png\").convert('RGB').resize((width, height))\n        for i in range(len(entity_prompts))\n    ]\n    negative_prompt = \"网格化，规则的网格，模糊, 低分辨率, 低质量, 变形, 畸形, 错误的解剖学, 变形的手, 变形的身体, 变形的脸, 变形的头发, 变形的眼睛, 变形的嘴巴\"\n    for seed in seeds:\n        # generate image\n        image = pipe(\n            prompt=global_prompt,\n            cfg_scale=4.0,\n            negative_prompt=negative_prompt,\n            num_inference_steps=40,\n            seed=seed,\n            height=height,\n            width=width,\n            eligen_entity_prompts=entity_prompts,\n            eligen_entity_masks=masks,\n        )\n        image.save(f\"eligen_poster_example_{example_id}_{seed}.png\")\n        image = Image.new(\"RGB\", (width, height), (0, 0, 0))\n        visualize_masks(image, masks, entity_prompts, f\"eligen_poster_example_{example_id}_mask_{seed}.png\")\n\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.float8_e4m3fn,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e4m3fn,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": 
\"cuda\",\n}\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\nsnapshot_download(\n    \"DiffSynth-Studio/Qwen-Image-EliGen-Poster\",\n    local_dir=\"models/DiffSynth-Studio/Qwen-Image-EliGen-Poster\",\n    allow_file_pattern=\"model.safetensors\",\n)\npipe.load_lora(pipe.dit, \"models/DiffSynth-Studio/Qwen-Image-EliGen-Poster/model.safetensors\", hotload=True)\nglobal_prompt = \"一张以柔粉紫为背景的海报，左侧有大号粉紫色文字“Qwen-Image EliGen-Poster”，粉紫色椭圆框内白色小字：“图像精确分区控制模型”。右侧有一只小兔子在拆礼物，旁边站着一只头顶迷你烟花发射器的小龙（卡通Q版）。背景有一些白云点缀。整体风格卡通可爱，传达节日惊喜的主题。\"\nentity_prompts = [\"粉紫色文字“Qwen-Image EliGen-Poster”\", \"粉紫色椭圆框内白色小字：“图像精确分区控制模型”\", \"一只小兔子在拆礼物，小兔子旁边站着一只头顶迷你烟花发射器的小龙（卡通Q版）\"]\nseed = [42]\nexample(pipe, seed, 1, global_prompt, entity_prompts)\n"
  },
  {
    "path": "examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py",
    "content": "import torch\nimport random\nfrom PIL import Image, ImageDraw, ImageFont\nfrom modelscope import dataset_snapshot_download, snapshot_download\nfrom diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\n\ndef visualize_masks(image, masks, mask_prompts, output_path, font_size=35, use_random_colors=False):\n    # Create a blank image for overlays\n    overlay = Image.new('RGBA', image.size, (0, 0, 0, 0))\n\n    colors = [\n        (165, 238, 173, 80),\n        (76, 102, 221, 80),\n        (221, 160, 77, 80),\n        (204, 93, 71, 80),\n        (145, 187, 149, 80),\n        (134, 141, 172, 80),\n        (157, 137, 109, 80),\n        (153, 104, 95, 80),\n        (165, 238, 173, 80),\n        (76, 102, 221, 80),\n        (221, 160, 77, 80),\n        (204, 93, 71, 80),\n        (145, 187, 149, 80),\n        (134, 141, 172, 80),\n        (157, 137, 109, 80),\n        (153, 104, 95, 80),\n    ]\n    # Generate random colors for each mask\n    if use_random_colors:\n        colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 80) for _ in range(len(masks))]\n\n    # Font settings\n    try:\n        font = ImageFont.truetype(\"wqy-zenhei.ttc\", font_size)  # Adjust as needed\n    except IOError:\n        font = ImageFont.load_default(font_size)\n\n    # Overlay each mask onto the overlay image\n    for mask, mask_prompt, color in zip(masks, mask_prompts, colors):\n        # Convert mask to RGBA mode\n        mask_rgba = mask.convert('RGBA')\n        mask_data = mask_rgba.getdata()\n        new_data = [(color if item[:3] == (255, 255, 255) else (0, 0, 0, 0)) for item in mask_data]\n        mask_rgba.putdata(new_data)\n\n        # Draw the mask prompt text on the mask\n        draw = ImageDraw.Draw(mask_rgba)\n        mask_bbox = mask.getbbox()  # Get the bounding box of the mask\n        text_position = (mask_bbox[0] + 10, mask_bbox[1] + 10)  # Adjust text position based on mask position\n        
draw.text(text_position, mask_prompt, fill=(255, 255, 255, 255), font=font)\n\n        # Alpha composite the overlay with this mask\n        overlay = Image.alpha_composite(overlay, mask_rgba)\n\n    # Composite the overlay onto the original image\n    result = Image.alpha_composite(image.convert('RGBA'), overlay)\n\n    # Save or display the resulting image\n    result.save(output_path)\n\n    return result\n\ndef example(pipe, seeds, example_id, global_prompt, entity_prompts):\n    dataset_snapshot_download(dataset_id=\"DiffSynth-Studio/examples_in_diffsynth\", local_dir=\"./\", allow_file_pattern=f\"data/examples/eligen/qwen-image/example_{example_id}/*.png\")\n    masks = [Image.open(f\"./data/examples/eligen/qwen-image/example_{example_id}/{i}.png\").convert('RGB').resize((1024, 1024)) for i in range(len(entity_prompts))]\n    negative_prompt = \"网格化，规则的网格，模糊, 低分辨率, 低质量, 变形, 畸形, 错误的解剖学, 变形的手, 变形的身体, 变形的脸, 变形的头发, 变形的眼睛, 变形的嘴巴\"\n    for seed in seeds:\n        # generate image\n        image = pipe(\n            prompt=global_prompt,\n            cfg_scale=4.0,\n            negative_prompt=negative_prompt,\n            num_inference_steps=40,\n            seed=seed,\n            height=1024,\n            width=1024,\n            eligen_entity_prompts=entity_prompts,\n            eligen_entity_masks=masks,\n        )\n        image.save(f\"eligen_example_{example_id}_{seed}.png\")\n        visualize_masks(image, masks, entity_prompts, f\"eligen_example_{example_id}_mask_{seed}.png\")\n\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.float8_e4m3fn,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e4m3fn,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        
ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\nsnapshot_download(\"DiffSynth-Studio/Qwen-Image-EliGen-V2\", local_dir=\"models/DiffSynth-Studio/Qwen-Image-EliGen-V2\", allow_file_pattern=\"model.safetensors\")\npipe.load_lora(pipe.dit, \"models/DiffSynth-Studio/Qwen-Image-EliGen-V2/model.safetensors\", hotload=True)\n\nseeds = [0]\n\nglobal_prompt = \"写实摄影风格. A beautiful asia woman wearing white dress, she is holding a mirror with her right arm, with a beach background.\"\nentity_prompts = [\"A beautiful woman\", \"mirror\", \"necklace\", \"glasses\", \"earring\", \"white dress\", \"jewelry headpiece\"]\nexample(pipe, seeds, 7, global_prompt, entity_prompts)\n\nglobal_prompt = \"写实摄影风格, 细节丰富。街头一位漂亮的女孩，穿着衬衫和短裤，手持写有“实体控制”的标牌，背景是繁忙的城市街道，阳光明媚，行人匆匆。\"\nentity_prompts = [\"一个漂亮的女孩\", \"标牌 '实体控制'\", \"短裤\", \"衬衫\"]\nexample(pipe, seeds, 4, global_prompt, entity_prompts)\n"
  },
  {
    "path": "examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py",
    "content": "from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nimport torch\nfrom PIL import Image, ImageDraw, ImageFont\nfrom modelscope import dataset_snapshot_download, snapshot_download\nimport random\n\n\ndef visualize_masks(image, masks, mask_prompts, output_path, font_size=35, use_random_colors=False):\n    # Create a blank image for overlays\n    overlay = Image.new('RGBA', image.size, (0, 0, 0, 0))\n\n    colors = [\n        (165, 238, 173, 80),\n        (76, 102, 221, 80),\n        (221, 160, 77, 80),\n        (204, 93, 71, 80),\n        (145, 187, 149, 80),\n        (134, 141, 172, 80),\n        (157, 137, 109, 80),\n        (153, 104, 95, 80),\n        (165, 238, 173, 80),\n        (76, 102, 221, 80),\n        (221, 160, 77, 80),\n        (204, 93, 71, 80),\n        (145, 187, 149, 80),\n        (134, 141, 172, 80),\n        (157, 137, 109, 80),\n        (153, 104, 95, 80),\n    ]\n    # Generate random colors for each mask\n    if use_random_colors:\n        colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 80) for _ in range(len(masks))]\n\n    # Font settings\n    try:\n        font = ImageFont.truetype(\"wqy-zenhei.ttc\", font_size)  # Adjust as needed\n    except IOError:\n        font = ImageFont.load_default(font_size)\n\n    # Overlay each mask onto the overlay image\n    for mask, mask_prompt, color in zip(masks, mask_prompts, colors):\n        # Convert mask to RGBA mode\n        mask_rgba = mask.convert('RGBA')\n        mask_data = mask_rgba.getdata()\n        new_data = [(color if item[:3] == (255, 255, 255) else (0, 0, 0, 0)) for item in mask_data]\n        mask_rgba.putdata(new_data)\n\n        # Draw the mask prompt text on the mask\n        draw = ImageDraw.Draw(mask_rgba)\n        mask_bbox = mask.getbbox()  # Get the bounding box of the mask\n        text_position = (mask_bbox[0] + 10, mask_bbox[1] + 10)  # Adjust text position based on mask position\n        
draw.text(text_position, mask_prompt, fill=(255, 255, 255, 255), font=font)\n\n        # Alpha composite the overlay with this mask\n        overlay = Image.alpha_composite(overlay, mask_rgba)\n\n    # Composite the overlay onto the original image\n    result = Image.alpha_composite(image.convert('RGBA'), overlay)\n\n    # Save or display the resulting image\n    result.save(output_path)\n\n    return result\n\ndef example(pipe, seeds, example_id, global_prompt, entity_prompts):\n    dataset_snapshot_download(dataset_id=\"DiffSynth-Studio/examples_in_diffsynth\", local_dir=\"./\", allow_file_pattern=f\"data/examples/eligen/qwen-image/example_{example_id}/*.png\")\n    masks = [Image.open(f\"./data/examples/eligen/qwen-image/example_{example_id}/{i}.png\").convert('RGB') for i in range(len(entity_prompts))]\n    negative_prompt = \"\"\n    for seed in seeds:\n        # generate image\n        image = pipe(\n            prompt=global_prompt,\n            cfg_scale=4.0,\n            negative_prompt=negative_prompt,\n            num_inference_steps=30,\n            seed=seed,\n            height=1024,\n            width=1024,\n            eligen_entity_prompts=entity_prompts,\n            eligen_entity_masks=masks,\n        )\n        image.save(f\"eligen_example_{example_id}_{seed}.png\")\n        visualize_masks(image, masks, entity_prompts, f\"eligen_example_{example_id}_mask_{seed}.png\")\n\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.float8_e4m3fn,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e4m3fn,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\", 
**vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\nsnapshot_download(\"DiffSynth-Studio/Qwen-Image-EliGen\", local_dir=\"models/DiffSynth-Studio/Qwen-Image-EliGen\", allow_file_pattern=\"model.safetensors\")\npipe.load_lora(pipe.dit, \"models/DiffSynth-Studio/Qwen-Image-EliGen/model.safetensors\", hotload=True)\n\n# example 1\nglobal_prompt = \"A breathtaking beauty of Raja Ampat by the late-night moonlight , one beautiful woman from behind wearing a pale blue long dress with soft glow, sitting at the top of a cliff looking towards the beach,pastell light colors, a group of small distant birds flying in far sky, a boat sailing on the sea, best quality, realistic, whimsical, fantastic, splash art, intricate detailed, hyperdetailed, maximalist style, photorealistic, concept art, sharp focus, harmony, serenity, tranquility, soft pastell colors,ambient occlusion, cozy ambient lighting, masterpiece, liiv1, linquivera, metix, mentixis, masterpiece, award winning, view from above\\n\"\nentity_prompts = [\"cliff\", \"sea\", \"moon\", \"sailing boat\", \"a seated beautiful woman\", \"pale blue long dress with soft glow\"]\nexample(pipe, [0], 1, global_prompt, entity_prompts)\n\n# example 2\nglobal_prompt = \"samurai girl wearing a kimono, she's holding a sword  glowing with red flame, her long hair is flowing in the wind, she is looking at a small bird perched on the back of her hand. ultra realist style. maximum image detail. 
maximum realistic render.\"\nentity_prompts = [\"flowing hair\", \"sword glowing with red flame\", \"A cute bird\", \"yellow belt\"]\nexample(pipe, [0], 2, global_prompt, entity_prompts)\n"
  },
  {
    "path": "examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py",
    "content": "from PIL import Image\nimport torch\nfrom modelscope import dataset_snapshot_download, snapshot_download\nfrom diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nfrom diffsynth.utils.controlnet import Annotator\n\nallow_file_pattern = [\"sk_model.pth\", \"sk_model2.pth\", \"dpt_hybrid-midas-501f0c75.pt\", \"ControlNetHED.pth\", \"body_pose_model.pth\", \"hand_pose_model.pth\", \"facenet.pth\", \"scannet.pt\"]\nsnapshot_download(\"lllyasviel/Annotators\", local_dir=\"models/Annotators\", allow_file_pattern=allow_file_pattern)\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.float8_e4m3fn,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e4m3fn,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\nsnapshot_download(\"DiffSynth-Studio/Qwen-Image-In-Context-Control-Union\", local_dir=\"models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union\", allow_file_pattern=\"model.safetensors\")\npipe.load_lora(pipe.dit, \"models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union/model.safetensors\", 
hotload=True)\n\ndataset_snapshot_download(dataset_id=\"DiffSynth-Studio/examples_in_diffsynth\", local_dir=\"./\", allow_file_pattern=f\"data/examples/qwen-image-context-control/image.jpg\")\norigin_image = Image.open(\"data/examples/qwen-image-context-control/image.jpg\").resize((1024, 1024))\nannotator_ids = ['openpose', 'canny', 'depth', 'lineart', 'softedge', 'normal']\nfor annotator_id in annotator_ids:\n    annotator = Annotator(processor_id=annotator_id, device=\"cuda\")\n    control_image = annotator(origin_image)\n    control_image.save(f\"{annotator.processor_id}.png\")\n\n    control_prompt = \"Context_Control. \"\n    prompt = f\"{control_prompt}一个穿着淡蓝色的漂亮女孩正在翩翩起舞，背景是梦幻的星空，光影交错，细节精致。\"\n    negative_prompt = \"网格化，规则的网格，模糊, 低分辨率, 低质量, 变形, 畸形, 错误的解剖学, 变形的手, 变形的身体, 变形的脸, 变形的头发, 变形的眼睛, 变形的嘴巴\"\n    image = pipe(prompt, seed=1, negative_prompt=negative_prompt, context_image=control_image, height=1024, width=1024)\n    image.save(f\"image_{annotator.processor_id}.png\")\n"
  },
  {
    "path": "examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control-V2.py",
    "content": "from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\nfrom PIL import Image\nimport torch\n\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.float8_e4m3fn,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e4m3fn,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"DiffSynth-Studio/Qwen-Image-Layered-Control\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image-Layered\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n)\npipe.load_lora(pipe.dit, ModelConfig(model_id=\"DiffSynth-Studio/Qwen-Image-Layered-Control-V2\", origin_file_pattern=\"model.safetensors\", **vram_config))\n\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/example_image_dataset\",\n    local_dir=\"./data/example_image_dataset\",\n    allow_file_pattern=\"layer_v2/*.png\"\n)\n\nprompt = \"Text 'APRIL'\"\ninput_image = Image.open(\"data/example_image_dataset/layer_v2/image_1.png\").convert(\"RGBA\").resize((1024, 1024))\nimage = pipe(\n    prompt, seed=0,\n    height=1024, width=1024,\n    layer_input_image=input_image, layer_num=0,\n    num_inference_steps=10, cfg_scale=4,\n)\nimage[0].save(\"image_prompt.png\")\n\nmask_image = Image.open(\"data/example_image_dataset/layer_v2/mask_2.png\").convert(\"RGBA\").resize((1024, 1024))\ninput_image = 
Image.open(\"data/example_image_dataset/layer_v2/image_2.png\").convert(\"RGBA\").resize((1024, 1024))\nimage = pipe(\n    prompt, seed=0,\n    height=1024, width=1024,\n    layer_input_image=input_image, layer_num=0,\n    context_image=mask_image,\n    num_inference_steps=10, cfg_scale=1.0,\n)\nimage[0].save(\"image_mask.png\")\n"
  },
  {
    "path": "examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control.py",
    "content": "from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nfrom modelscope import snapshot_download\nfrom PIL import Image\nimport torch\n\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.float8_e4m3fn,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e4m3fn,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"DiffSynth-Studio/Qwen-Image-Layered-Control\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image-Layered\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\", **vram_config),\n    ],\n    processor_config=ModelConfig(model_id=\"Qwen/Qwen-Image-Edit\", origin_file_pattern=\"processor/\"),\n)\n\nsnapshot_download(\n    model_id=\"DiffSynth-Studio/Qwen-Image-Layered-Control\",\n    allow_file_pattern=\"assets/image_1_input.png\",\n    local_dir=\"data/layered_input\"\n)\n\nprompt = \"A cartoon skeleton character wearing a purple hat and holding a gift box\"\ninput_image = Image.open(\"data/layered_input/assets/image_1_input.png\").convert(\"RGBA\").resize((1024, 1024))\nimages = pipe(\n    prompt,\n    seed=0,\n    num_inference_steps=30, cfg_scale=4,\n    height=1024, width=1024,\n    layer_input_image=input_image,\n    layer_num=0,\n)\nimages[0].save(\"image.png\")\n"
  },
  {
    "path": "examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py",
    "content": "from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\nfrom PIL import Image\nimport torch\n\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.float8_e4m3fn,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e4m3fn,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image-Layered\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image-Layered\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\", **vram_config),\n    ],\n    processor_config=ModelConfig(model_id=\"Qwen/Qwen-Image-Edit\", origin_file_pattern=\"processor/\"),\n)\n\ndataset_snapshot_download(\n    \"DiffSynth-Studio/example_image_dataset\",\n    allow_patterns=\"layer/image.png\",\n    local_dir=\"data/example_image_dataset\"\n)\n\n# Prompt should be provided to the pipeline. Our pipeline will not generate the prompt.\nprompt = 'A cheerful child with brown hair is waving enthusiastically under a bright blue sky filled with colorful confetti and balloons. The word \"HELLO!\" is prominently displayed in bold red letters above the child, while \"Have a Great Day!\" appears in elegant cursive at the bottom right corner. 
The scene is vibrant and festive, with a mix of pastel colors and dynamic shapes creating a joyful atmosphere.'\n# Height and width should be consistent with input_image and be divided evenly by 16\ninput_image = Image.open(\"data/example_image_dataset/layer/image.png\").convert(\"RGBA\").resize((864, 480))\nimages = pipe(\n    prompt,\n    seed=1, num_inference_steps=50,\n    height=480, width=864,\n    layer_input_image=input_image, layer_num=3,\n)\nfor i, image in enumerate(images):\n    if i == 0: continue # The first image is the input image.\n    image.save(f\"image_{i}.png\")\n"
  },
  {
    "path": "examples/qwen_image/model_inference_low_vram/Qwen-Image-i2L.py",
    "content": "from diffsynth.pipelines.qwen_image import (\n    QwenImagePipeline, ModelConfig,\n    QwenImageUnit_Image2LoRAEncode, QwenImageUnit_Image2LoRADecode\n)\nfrom diffsynth.utils.lora import merge_lora\nfrom diffsynth import load_state_dict\nfrom modelscope import snapshot_download\nfrom safetensors.torch import save_file\nimport torch\nfrom PIL import Image\n\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\nvram_config_disk_offload = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": \"disk\",\n    \"onload_device\": \"disk\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\n\ndef demo_style():\n    # Load models\n    pipe = QwenImagePipeline.from_pretrained(\n        torch_dtype=torch.bfloat16,\n        device=\"cuda\",\n        model_configs=[\n            ModelConfig(model_id=\"DiffSynth-Studio/General-Image-Encoders\", origin_file_pattern=\"SigLIP2-G384/model.safetensors\", **vram_config_disk_offload),\n            ModelConfig(model_id=\"DiffSynth-Studio/General-Image-Encoders\", origin_file_pattern=\"DINOv3-7B/model.safetensors\", **vram_config_disk_offload),\n            ModelConfig(model_id=\"DiffSynth-Studio/Qwen-Image-i2L\", origin_file_pattern=\"Qwen-Image-i2L-Style.safetensors\", **vram_config_disk_offload),\n        ],\n        processor_config=ModelConfig(model_id=\"Qwen/Qwen-Image-Edit\", origin_file_pattern=\"processor/\"),\n        vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n    )\n\n    # Load images\n    snapshot_download(\n        model_id=\"DiffSynth-Studio/Qwen-Image-i2L\",\n   
     allow_file_pattern=\"assets/style/1/*\",\n        local_dir=\"data/examples\"\n    )\n    images = [\n        Image.open(\"data/examples/assets/style/1/0.jpg\"),\n        Image.open(\"data/examples/assets/style/1/1.jpg\"),\n        Image.open(\"data/examples/assets/style/1/2.jpg\"),\n        Image.open(\"data/examples/assets/style/1/3.jpg\"),\n        Image.open(\"data/examples/assets/style/1/4.jpg\"),\n    ]\n\n    # Model inference\n    with torch.no_grad():\n        embs = QwenImageUnit_Image2LoRAEncode().process(pipe, image2lora_images=images)\n        lora = QwenImageUnit_Image2LoRADecode().process(pipe, **embs)[\"lora\"]\n    save_file(lora, \"model_style.safetensors\")\n\n\ndef demo_coarse_fine_bias():\n    # Load models\n    pipe = QwenImagePipeline.from_pretrained(\n        torch_dtype=torch.bfloat16,\n        device=\"cuda\",\n        model_configs=[\n            ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\", **vram_config_disk_offload),\n            ModelConfig(model_id=\"DiffSynth-Studio/General-Image-Encoders\", origin_file_pattern=\"SigLIP2-G384/model.safetensors\", **vram_config_disk_offload),\n            ModelConfig(model_id=\"DiffSynth-Studio/General-Image-Encoders\", origin_file_pattern=\"DINOv3-7B/model.safetensors\", **vram_config_disk_offload),\n            ModelConfig(model_id=\"DiffSynth-Studio/Qwen-Image-i2L\", origin_file_pattern=\"Qwen-Image-i2L-Coarse.safetensors\", **vram_config_disk_offload),\n            ModelConfig(model_id=\"DiffSynth-Studio/Qwen-Image-i2L\", origin_file_pattern=\"Qwen-Image-i2L-Fine.safetensors\", **vram_config_disk_offload),\n        ],\n        processor_config=ModelConfig(model_id=\"Qwen/Qwen-Image-Edit\", origin_file_pattern=\"processor/\"),\n        vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n    )\n\n    # Load images\n    snapshot_download(\n        model_id=\"DiffSynth-Studio/Qwen-Image-i2L\",\n        
allow_file_pattern=\"assets/lora/3/*\",\n        local_dir=\"data/examples\"\n    )\n    images = [\n        Image.open(\"data/examples/assets/lora/3/0.jpg\"),\n        Image.open(\"data/examples/assets/lora/3/1.jpg\"),\n        Image.open(\"data/examples/assets/lora/3/2.jpg\"),\n        Image.open(\"data/examples/assets/lora/3/3.jpg\"),\n        Image.open(\"data/examples/assets/lora/3/4.jpg\"),\n        Image.open(\"data/examples/assets/lora/3/5.jpg\"),\n    ]\n\n    # Model inference\n    with torch.no_grad():\n        embs = QwenImageUnit_Image2LoRAEncode().process(pipe, image2lora_images=images)\n        lora = QwenImageUnit_Image2LoRADecode().process(pipe, **embs)[\"lora\"]\n        lora_bias = ModelConfig(model_id=\"DiffSynth-Studio/Qwen-Image-i2L\", origin_file_pattern=\"Qwen-Image-i2L-Bias.safetensors\")\n        lora_bias.download_if_necessary()\n        lora_bias = load_state_dict(lora_bias.path, torch_dtype=torch.bfloat16, device=\"cuda\")\n        lora = merge_lora([lora, lora_bias])\n    save_file(lora, \"model_coarse_fine_bias.safetensors\")\n\n\ndef generate_image(lora_path, prompt, seed):\n    pipe = QwenImagePipeline.from_pretrained(\n        torch_dtype=torch.bfloat16,\n        device=\"cuda\",\n        model_configs=[\n            ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\", **vram_config),\n            ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\", **vram_config),\n            ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\", **vram_config),\n        ],\n        tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n        vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n    )\n    pipe.load_lora(pipe.dit, lora_path)\n    image = pipe(prompt, seed=seed, height=1024, width=1024, num_inference_steps=50)\n   
 return image\n\n\ndemo_style()\nimage = generate_image(\"model_style.safetensors\", \"a cat\", 0)\nimage.save(\"image_1.jpg\")\n\ndemo_coarse_fine_bias()\nimage = generate_image(\"model_coarse_fine_bias.safetensors\", \"bowl\", 1)\nimage.save(\"image_2.jpg\")\n"
  },
  {
    "path": "examples/qwen_image/model_inference_low_vram/Qwen-Image.py",
    "content": "from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nimport torch\n\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.float8_e4m3fn,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.float8_e4m3fn,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\nprompt = \"精致肖像，水下少女，蓝裙飘逸，发丝轻扬，光影透澈，气泡环绕，面容恬静，细节精致，梦幻唯美。\"\nimage = pipe(prompt, seed=0, num_inference_steps=40)\nimage.save(\"image.jpg\")\n"
  },
  {
    "path": "examples/qwen_image/model_training/full/FireRed-Image-Edit-1.0.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"qwen_image/FireRed-Image-Edit-1.0/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch --config_file examples/qwen_image/model_training/full/accelerate_config_zero2offload.yaml examples/qwen_image/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/qwen_image/FireRed-Image-Edit-1.0 \\\n  --dataset_metadata_path data/diffsynth_example_dataset/qwen_image/FireRed-Image-Edit-1.0/metadata.json \\\n  --data_file_keys \"image,edit_image\" \\\n  --extra_inputs \"edit_image\" \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 50 \\\n  --model_id_with_origin_paths \"FireRedTeam/FireRed-Image-Edit-1.0:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 2 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/FireRed-Image-Edit-1.0_full\" \\\n  --trainable_models \"dit\" \\\n  --use_gradient_checkpointing \\\n  --find_unused_parameters\n"
  },
  {
    "path": "examples/qwen_image/model_training/full/FireRed-Image-Edit-1.1.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"qwen_image/FireRed-Image-Edit-1.1/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch --config_file examples/qwen_image/model_training/full/accelerate_config_zero2offload.yaml examples/qwen_image/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/qwen_image/FireRed-Image-Edit-1.1 \\\n  --dataset_metadata_path data/diffsynth_example_dataset/qwen_image/FireRed-Image-Edit-1.1/metadata.json \\\n  --data_file_keys \"image,edit_image\" \\\n  --extra_inputs \"edit_image\" \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 50 \\\n  --model_id_with_origin_paths \"FireRedTeam/FireRed-Image-Edit-1.1:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 2 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/FireRed-Image-Edit-1.1_full\" \\\n  --trainable_models \"dit\" \\\n  --use_gradient_checkpointing \\\n  --find_unused_parameters\n"
  },
  {
    "path": "examples/qwen_image/model_training/full/Qwen-Image-2512.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"qwen_image/Qwen-Image-2512/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch --config_file examples/qwen_image/model_training/full/accelerate_config_zero2offload.yaml examples/qwen_image/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-2512 \\\n  --dataset_metadata_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-2512/metadata.csv \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 50 \\\n  --model_id_with_origin_paths \"Qwen/Qwen-Image-2512:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 2 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Qwen-Image-2512_full\" \\\n  --trainable_models \"dit\" \\\n  --use_gradient_checkpointing \\\n  --find_unused_parameters\n"
  },
  {
    "path": "examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"qwen_image/Qwen-Image-Blockwise-ControlNet-Canny/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/qwen_image/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-Blockwise-ControlNet-Canny \\\n  --dataset_metadata_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-Blockwise-ControlNet-Canny/metadata.csv \\\n  --data_file_keys \"image,blockwise_controlnet_image\" \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 400 \\\n  --model_id_with_origin_paths \"Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny:model.safetensors\" \\\n  --learning_rate 1e-3 \\\n  --num_epochs 2 \\\n  --remove_prefix_in_ckpt \"pipe.blockwise_controlnet.models.0.\" \\\n  --output_path \"./models/train/Qwen-Image-Blockwise-ControlNet-Canny_full\" \\\n  --trainable_models \"blockwise_controlnet\" \\\n  --extra_inputs \"blockwise_controlnet_image\" \\\n  --use_gradient_checkpointing \\\n  --find_unused_parameters\n\n# If you want to pre-train a Blockwise ControlNet from scratch,\n# please run the following script to first generate the initialized model weights file,\n# and then start training with a high learning rate (1e-3).\n\n# python examples/qwen_image/model_training/scripts/Qwen-Image-Blockwise-ControlNet-Initialize.py\n\n# accelerate launch examples/qwen_image/model_training/train.py \\\n#   --dataset_base_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-Blockwise-ControlNet-Canny \\\n#   --dataset_metadata_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-Blockwise-ControlNet-Canny/metadata.csv \\\n#   --data_file_keys \"image,blockwise_controlnet_image\" \\\n#   --max_pixels 1048576 \\\n#   --dataset_repeat 50 \\\n#  
 --model_id_with_origin_paths \"Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors\" \\\n#   --model_paths '[\"models/blockwise_controlnet.safetensors\"]' \\\n#   --learning_rate 1e-3 \\\n#   --num_epochs 2 \\\n#   --remove_prefix_in_ckpt \"pipe.blockwise_controlnet.models.0.\" \\\n#   --output_path \"./models/train/Qwen-Image-Blockwise-ControlNet-Canny_full\" \\\n#   --trainable_models \"blockwise_controlnet\" \\\n#   --extra_inputs \"blockwise_controlnet_image\" \\\n#   --use_gradient_checkpointing \\\n#   --find_unused_parameters\n"
  },
  {
    "path": "examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"qwen_image/Qwen-Image-Blockwise-ControlNet-Depth/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/qwen_image/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-Blockwise-ControlNet-Depth \\\n  --dataset_metadata_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-Blockwise-ControlNet-Depth/metadata.csv \\\n  --data_file_keys \"image,blockwise_controlnet_image\" \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 400 \\\n  --model_id_with_origin_paths \"Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth:model.safetensors\" \\\n  --learning_rate 1e-3 \\\n  --num_epochs 2 \\\n  --remove_prefix_in_ckpt \"pipe.blockwise_controlnet.models.0.\" \\\n  --output_path \"./models/train/Qwen-Image-Blockwise-ControlNet-Depth_full\" \\\n  --trainable_models \"blockwise_controlnet\" \\\n  --extra_inputs \"blockwise_controlnet_image\" \\\n  --use_gradient_checkpointing \\\n  --find_unused_parameters\n\n# If you want to pre-train a Blockwise ControlNet from scratch,\n# please run the following script to first generate the initialized model weights file,\n# and then start training with a high learning rate (1e-3).\n\n# python examples/qwen_image/model_training/scripts/Qwen-Image-Blockwise-ControlNet-Initialize.py\n\n# accelerate launch examples/qwen_image/model_training/train.py \\\n#   --dataset_base_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-Blockwise-ControlNet-Depth \\\n#   --dataset_metadata_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-Blockwise-ControlNet-Depth/metadata.csv \\\n#   --data_file_keys \"image,blockwise_controlnet_image\" \\\n#   --max_pixels 1048576 \\\n#   --dataset_repeat 50 \\\n#  
 --model_id_with_origin_paths \"Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors\" \\\n#   --model_paths '[\"models/blockwise_controlnet.safetensors\"]' \\\n#   --learning_rate 1e-3 \\\n#   --num_epochs 2 \\\n#   --remove_prefix_in_ckpt \"pipe.blockwise_controlnet.models.0.\" \\\n#   --output_path \"./models/train/Qwen-Image-Blockwise-ControlNet-Depth_full\" \\\n#   --trainable_models \"blockwise_controlnet\" \\\n#   --extra_inputs \"blockwise_controlnet_image\" \\\n#   --use_gradient_checkpointing \\\n#   --find_unused_parameters\n"
  },
  {
    "path": "examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"qwen_image/Qwen-Image-Blockwise-ControlNet-Inpaint/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch --config_file examples/qwen_image/model_training/full/accelerate_config.yaml examples/qwen_image/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-Blockwise-ControlNet-Inpaint \\\n  --dataset_metadata_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-Blockwise-ControlNet-Inpaint/metadata.csv \\\n  --data_file_keys \"image,blockwise_controlnet_image,blockwise_controlnet_inpaint_mask\" \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 400 \\\n  --model_id_with_origin_paths \"Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint:model.safetensors\" \\\n  --learning_rate 1e-3 \\\n  --num_epochs 2 \\\n  --remove_prefix_in_ckpt \"pipe.blockwise_controlnet.models.0.\" \\\n  --output_path \"./models/train/Qwen-Image-Blockwise-ControlNet-Inpaint_full\" \\\n  --trainable_models \"blockwise_controlnet\" \\\n  --extra_inputs \"blockwise_controlnet_image,blockwise_controlnet_inpaint_mask\" \\\n  --use_gradient_checkpointing \\\n  --find_unused_parameters\n\n# If you want to pre-train a Inpaint Blockwise ControlNet from scratch,\n# please run the following script to first generate the initialized model weights file,\n# and then start training with a high learning rate (1e-3).\n\n# python examples/qwen_image/model_training/scripts/Qwen-Image-Blockwise-ControlNet-Inpaint-Initialize.py\n\n# accelerate launch --config_file examples/qwen_image/model_training/full/accelerate_config.yaml examples/qwen_image/model_training/train.py \\\n#   --dataset_base_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-Blockwise-ControlNet-Inpaint 
\\\n#   --dataset_metadata_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-Blockwise-ControlNet-Inpaint/metadata.csv \\\n#   --data_file_keys \"image,blockwise_controlnet_image,blockwise_controlnet_inpaint_mask\" \\\n#   --max_pixels 1048576 \\\n#   --dataset_repeat 50 \\\n#   --model_id_with_origin_paths \"Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors\" \\\n#   --model_paths '[\"models/blockwise_controlnet_inpaint.safetensors\"]' \\\n#   --learning_rate 1e-3 \\\n#   --num_epochs 2 \\\n#   --remove_prefix_in_ckpt \"pipe.blockwise_controlnet.models.0.\" \\\n#   --output_path \"./models/train/Qwen-Image-Blockwise-ControlNet-Inpaint_full\" \\\n#   --trainable_models \"blockwise_controlnet\" \\\n#   --extra_inputs \"blockwise_controlnet_image,blockwise_controlnet_inpaint_mask\" \\\n#   --use_gradient_checkpointing \\\n#   --find_unused_parameters\n"
  },
  {
    "path": "examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"qwen_image/Qwen-Image-Distill-Full/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch --config_file examples/qwen_image/model_training/full/accelerate_config_zero2offload.yaml examples/qwen_image/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-Distill-Full \\\n  --dataset_metadata_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-Distill-Full/metadata.csv \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 50 \\\n  --model_id_with_origin_paths \"DiffSynth-Studio/Qwen-Image-Distill-Full:diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 2 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Qwen-Image-Distill-Full_full\" \\\n  --trainable_models \"dit\" \\\n  --use_gradient_checkpointing \\\n  --find_unused_parameters\n"
  },
  {
    "path": "examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"qwen_image/Qwen-Image-Edit-2509/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch --config_file examples/qwen_image/model_training/full/accelerate_config_zero2offload.yaml examples/qwen_image/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-Edit-2509 \\\n  --dataset_metadata_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-Edit-2509/metadata.json \\\n  --data_file_keys \"image,edit_image\" \\\n  --extra_inputs \"edit_image\" \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 50 \\\n  --model_id_with_origin_paths \"Qwen/Qwen-Image-Edit-2509:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 2 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Qwen-Image-Edit-2509_full\" \\\n  --trainable_models \"dit\" \\\n  --use_gradient_checkpointing \\\n  --find_unused_parameters\n"
  },
  {
    "path": "examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"qwen_image/Qwen-Image-Edit-2511/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch --config_file examples/qwen_image/model_training/full/accelerate_config_zero2offload.yaml examples/qwen_image/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-Edit-2511 \\\n  --dataset_metadata_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-Edit-2511/metadata.json \\\n  --data_file_keys \"image,edit_image\" \\\n  --extra_inputs \"edit_image\" \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 50 \\\n  --model_id_with_origin_paths \"Qwen/Qwen-Image-Edit-2511:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 2 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Qwen-Image-Edit-2511_full\" \\\n  --trainable_models \"dit\" \\\n  --use_gradient_checkpointing \\\n  --find_unused_parameters \\\n  --zero_cond_t # This is a special parameter introduced by Qwen-Image-Edit-2511. Please enable it for this model.\n"
  },
  {
    "path": "examples/qwen_image/model_training/full/Qwen-Image-Edit.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"qwen_image/Qwen-Image-Edit/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch --config_file examples/qwen_image/model_training/full/accelerate_config_zero2offload.yaml examples/qwen_image/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-Edit \\\n  --dataset_metadata_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-Edit/metadata.csv \\\n  --data_file_keys \"image,edit_image\" \\\n  --extra_inputs \"edit_image\" \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 50 \\\n  --model_id_with_origin_paths \"Qwen/Qwen-Image-Edit:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 2 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Qwen-Image-Edit_full\" \\\n  --trainable_models \"dit\" \\\n  --use_gradient_checkpointing \\\n  --find_unused_parameters\n"
  },
  {
    "path": "examples/qwen_image/model_training/full/Qwen-Image-Layered-Control.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"qwen_image/Qwen-Image-Layered-Control/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch --config_file examples/qwen_image/model_training/full/accelerate_config_zero2offload.yaml examples/qwen_image/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-Layered-Control \\\n  --dataset_metadata_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-Layered-Control/metadata.json \\\n  --data_file_keys \"image,layer_input_image\" \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 50 \\\n  --model_id_with_origin_paths \"DiffSynth-Studio/Qwen-Image-Layered-Control:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image-Layered:vae/diffusion_pytorch_model.safetensors\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 2 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Qwen-Image-Layered-Control_full\" \\\n  --trainable_models \"dit\" \\\n  --extra_inputs \"layer_num,layer_input_image\" \\\n  --use_gradient_checkpointing \\\n  --dataset_num_workers 8 \\\n  --find_unused_parameters\n"
  },
  {
    "path": "examples/qwen_image/model_training/full/Qwen-Image-Layered.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"qwen_image/Qwen-Image-Layered/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch --config_file examples/qwen_image/model_training/full/accelerate_config_zero2offload.yaml examples/qwen_image/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-Layered \\\n  --dataset_metadata_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-Layered/metadata.json \\\n  --data_file_keys \"image,layer_input_image\" \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 50 \\\n  --model_id_with_origin_paths \"Qwen/Qwen-Image-Layered:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image-Layered:vae/diffusion_pytorch_model.safetensors\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 2 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Qwen-Image-Layered_full\" \\\n  --trainable_models \"dit\" \\\n  --extra_inputs \"layer_num,layer_input_image\" \\\n  --use_gradient_checkpointing \\\n  --dataset_num_workers 8 \\\n  --find_unused_parameters\n"
  },
  {
    "path": "examples/qwen_image/model_training/full/Qwen-Image.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"qwen_image/Qwen-Image/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch --config_file examples/qwen_image/model_training/full/accelerate_config_zero2offload.yaml examples/qwen_image/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/qwen_image/Qwen-Image \\\n  --dataset_metadata_path data/diffsynth_example_dataset/qwen_image/Qwen-Image/metadata.csv \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 50 \\\n  --model_id_with_origin_paths \"Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 2 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Qwen-Image_full\" \\\n  --trainable_models \"dit\" \\\n  --use_gradient_checkpointing \\\n  --find_unused_parameters\n"
  },
  {
    "path": "examples/qwen_image/model_training/full/accelerate_config.yaml",
    "content": "compute_environment: LOCAL_MACHINE\ndebug: false\ndeepspeed_config:\n  gradient_accumulation_steps: 1\n  offload_optimizer_device: none\n  offload_param_device: none\n  zero3_init_flag: false\n  zero_stage: 2\ndistributed_type: DEEPSPEED\ndowncast_bf16: 'no'\nenable_cpu_affinity: false\nmachine_rank: 0\nmain_training_function: main\nmixed_precision: bf16\nnum_machines: 1\nnum_processes: 8\nrdzv_backend: static\nsame_network: true\ntpu_env: []\ntpu_use_cluster: false\ntpu_use_sudo: false\nuse_cpu: false\n"
  },
  {
    "path": "examples/qwen_image/model_training/full/accelerate_config_zero2offload.yaml",
    "content": "compute_environment: LOCAL_MACHINE\ndebug: false\ndeepspeed_config:\n  gradient_accumulation_steps: 1\n  offload_optimizer_device: 'cpu'\n  offload_param_device: 'cpu'\n  zero3_init_flag: false\n  zero_stage: 2\ndistributed_type: DEEPSPEED\ndowncast_bf16: 'no'\nenable_cpu_affinity: false\nmachine_rank: 0\nmain_training_function: main\nmixed_precision: bf16\nnum_machines: 1\nnum_processes: 8\nrdzv_backend: static\nsame_network: true\ntpu_env: []\ntpu_use_cluster: false\ntpu_use_sudo: false\nuse_cpu: false\n"
  },
  {
    "path": "examples/qwen_image/model_training/full/accelerate_config_zero3.yaml",
    "content": "compute_environment: LOCAL_MACHINE\ndebug: false\ndeepspeed_config:\n  gradient_accumulation_steps: 1\n  offload_optimizer_device: none\n  offload_param_device: none\n  zero3_init_flag: true\n  zero3_save_16bit_model: true\n  zero_stage: 3\ndistributed_type: DEEPSPEED\ndowncast_bf16: 'no'\nenable_cpu_affinity: false\nmachine_rank: 0\nmain_training_function: main\nmixed_precision: bf16\nnum_machines: 1\nnum_processes: 8\nrdzv_backend: static\nsame_network: true\ntpu_env: []\ntpu_use_cluster: false\ntpu_use_sudo: false\nuse_cpu: false\n"
  },
  {
    "path": "examples/qwen_image/model_training/lora/FireRed-Image-Edit-1.0.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"qwen_image/FireRed-Image-Edit-1.0/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/qwen_image/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/qwen_image/FireRed-Image-Edit-1.0 \\\n  --dataset_metadata_path data/diffsynth_example_dataset/qwen_image/FireRed-Image-Edit-1.0/metadata.json \\\n  --data_file_keys \"image,edit_image\" \\\n  --extra_inputs \"edit_image\" \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 50 \\\n  --model_id_with_origin_paths \"FireRedTeam/FireRed-Image-Edit-1.0:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/FireRed-Image-Edit-1.0_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1\" \\\n  --lora_rank 32 \\\n  --use_gradient_checkpointing \\\n  --dataset_num_workers 8 \\\n  --find_unused_parameters\n"
  },
  {
    "path": "examples/qwen_image/model_training/lora/FireRed-Image-Edit-1.1.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"qwen_image/FireRed-Image-Edit-1.1/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/qwen_image/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/qwen_image/FireRed-Image-Edit-1.1 \\\n  --dataset_metadata_path data/diffsynth_example_dataset/qwen_image/FireRed-Image-Edit-1.1/metadata.json \\\n  --data_file_keys \"image,edit_image\" \\\n  --extra_inputs \"edit_image\" \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 50 \\\n  --model_id_with_origin_paths \"FireRedTeam/FireRed-Image-Edit-1.1:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/FireRed-Image-Edit-1.1_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1\" \\\n  --lora_rank 32 \\\n  --use_gradient_checkpointing \\\n  --dataset_num_workers 8 \\\n  --find_unused_parameters\n"
  },
  {
    "path": "examples/qwen_image/model_training/lora/Qwen-Image-2512.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"qwen_image/Qwen-Image-2512/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/qwen_image/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-2512 \\\n  --dataset_metadata_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-2512/metadata.csv \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 50 \\\n  --model_id_with_origin_paths \"Qwen/Qwen-Image-2512:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Qwen-Image-2512_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1\" \\\n  --lora_rank 32 \\\n  --use_gradient_checkpointing \\\n  --dataset_num_workers 8 \\\n  --find_unused_parameters\n"
  },
  {
    "path": "examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Canny.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"qwen_image/Qwen-Image-Blockwise-ControlNet-Canny/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/qwen_image/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-Blockwise-ControlNet-Canny \\\n  --dataset_metadata_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-Blockwise-ControlNet-Canny/metadata.csv \\\n  --data_file_keys \"image,blockwise_controlnet_image\" \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 50 \\\n  --model_id_with_origin_paths \"Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny:model.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Qwen-Image-Blockwise-ControlNet-Canny_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1\" \\\n  --lora_rank 32 \\\n  --extra_inputs \"blockwise_controlnet_image\" \\\n  --use_gradient_checkpointing \\\n  --find_unused_parameters\n"
  },
  {
    "path": "examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Depth.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"qwen_image/Qwen-Image-Blockwise-ControlNet-Depth/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/qwen_image/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-Blockwise-ControlNet-Depth \\\n  --dataset_metadata_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-Blockwise-ControlNet-Depth/metadata.csv \\\n  --data_file_keys \"image,blockwise_controlnet_image\" \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 50 \\\n  --model_id_with_origin_paths \"Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth:model.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Qwen-Image-Blockwise-ControlNet-Depth_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1\" \\\n  --lora_rank 32 \\\n  --extra_inputs \"blockwise_controlnet_image\" \\\n  --use_gradient_checkpointing \\\n  --find_unused_parameters\n"
  },
  {
    "path": "examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"qwen_image/Qwen-Image-Blockwise-ControlNet-Inpaint/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/qwen_image/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-Blockwise-ControlNet-Inpaint \\\n  --dataset_metadata_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-Blockwise-ControlNet-Inpaint/metadata.csv \\\n  --data_file_keys \"image,blockwise_controlnet_image,blockwise_controlnet_inpaint_mask\" \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 50 \\\n  --model_id_with_origin_paths \"Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint:model.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Qwen-Image-Blockwise-ControlNet-Inpaint_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1\" \\\n  --lora_rank 32 \\\n  --extra_inputs \"blockwise_controlnet_image,blockwise_controlnet_inpaint_mask\" \\\n  --use_gradient_checkpointing \\\n  --find_unused_parameters\n"
  },
  {
    "path": "examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"qwen_image/Qwen-Image-Distill-Full/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/qwen_image/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-Distill-Full \\\n  --dataset_metadata_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-Distill-Full/metadata.csv \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 50 \\\n  --model_id_with_origin_paths \"DiffSynth-Studio/Qwen-Image-Distill-Full:diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Qwen-Image-Distill-Full_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1\" \\\n  --lora_rank 32 \\\n  --use_gradient_checkpointing \\\n  --find_unused_parameters\n"
  },
  {
    "path": "examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"qwen_image/Qwen-Image-Distill-LoRA/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/qwen_image/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-Distill-LoRA \\\n  --dataset_metadata_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-Distill-LoRA/metadata.csv \\\n  --data_file_keys \"image\" \\\n  --extra_inputs \"seed,rand_device,num_inference_steps,cfg_scale\" \\\n  --height 1328 \\\n  --width 1328 \\\n  --dataset_repeat 50 \\\n  --model_id_with_origin_paths \"Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Qwen-Image-Distill-LoRA_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1\" \\\n  --lora_rank 32 \\\n  --use_gradient_checkpointing \\\n  --dataset_num_workers 8 \\\n  --find_unused_parameters \\\n  --task direct_distill\n\n# This is an experimental training feature designed to directly distill the model, enabling generation results with fewer steps to approximate those achieved with more steps.\n# The model (https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA) is trained using this script.\n# The sample dataset is provided solely to demonstrate the dataset format. For actual usage, please construct a larger dataset using the base model.\n"
  },
  {
    "path": "examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"qwen_image/Qwen-Image-Edit-2509/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/qwen_image/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-Edit-2509 \\\n  --dataset_metadata_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-Edit-2509/metadata.json \\\n  --data_file_keys \"image,edit_image\" \\\n  --extra_inputs \"edit_image\" \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 50 \\\n  --model_id_with_origin_paths \"Qwen/Qwen-Image-Edit-2509:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Qwen-Image-Edit-2509_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1\" \\\n  --lora_rank 32 \\\n  --use_gradient_checkpointing \\\n  --dataset_num_workers 8 \\\n  --find_unused_parameters\n"
  },
  {
    "path": "examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"qwen_image/Qwen-Image-Edit-2511/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/qwen_image/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-Edit-2511 \\\n  --dataset_metadata_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-Edit-2511/metadata.json \\\n  --data_file_keys \"image,edit_image\" \\\n  --extra_inputs \"edit_image\" \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 50 \\\n  --model_id_with_origin_paths \"Qwen/Qwen-Image-Edit-2511:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Qwen-Image-Edit-2511_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1\" \\\n  --lora_rank 32 \\\n  --use_gradient_checkpointing \\\n  --dataset_num_workers 8 \\\n  --find_unused_parameters \\\n  --zero_cond_t # This is a special parameter introduced by Qwen-Image-Edit-2511. Please enable it for this model.\n"
  },
  {
    "path": "examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"qwen_image/Qwen-Image-Edit/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/qwen_image/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-Edit \\\n  --dataset_metadata_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-Edit/metadata.csv \\\n  --data_file_keys \"image,edit_image\" \\\n  --extra_inputs \"edit_image\" \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 50 \\\n  --model_id_with_origin_paths \"Qwen/Qwen-Image-Edit:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Qwen-Image-Edit_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1\" \\\n  --lora_rank 32 \\\n  --use_gradient_checkpointing \\\n  --dataset_num_workers 8 \\\n  --find_unused_parameters\n"
  },
  {
    "path": "examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"qwen_image/Qwen-Image-EliGen-Poster/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/qwen_image/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-EliGen-Poster \\\n  --dataset_metadata_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-EliGen-Poster/metadata.json \\\n  --data_file_keys \"image,eligen_entity_masks\" \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 50 \\\n  --model_id_with_origin_paths \"Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Qwen-Image-EliGen-Poster_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1\" \\\n  --lora_rank 32 \\\n  --extra_inputs \"eligen_entity_masks,eligen_entity_prompts\" \\\n  --use_gradient_checkpointing \\\n  --find_unused_parameters \\\n  --lora_checkpoint \"models/DiffSynth-Studio/Qwen-Image-EliGen-V2/model.safetensors\"\n"
  },
  {
    "path": "examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"qwen_image/Qwen-Image-EliGen/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/qwen_image/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-EliGen \\\n  --dataset_metadata_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-EliGen/metadata.json \\\n  --data_file_keys \"image,eligen_entity_masks\" \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 50 \\\n  --model_id_with_origin_paths \"Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Qwen-Image-EliGen_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1\" \\\n  --lora_rank 32 \\\n  --extra_inputs \"eligen_entity_masks,eligen_entity_prompts\" \\\n  --use_gradient_checkpointing \\\n  --find_unused_parameters\n"
  },
  {
    "path": "examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"qwen_image/Qwen-Image-In-Context-Control-Union/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/qwen_image/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-In-Context-Control-Union \\\n  --dataset_metadata_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-In-Context-Control-Union/metadata.csv \\\n  --data_file_keys \"image,context_image\" \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 50 \\\n  --model_id_with_origin_paths \"Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Qwen-Image-In-Context-Control-Union_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1\" \\\n  --lora_rank 64 \\\n  --lora_checkpoint \"models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union/model.safetensors\" \\\n  --extra_inputs \"context_image\" \\\n  --use_gradient_checkpointing \\\n  --find_unused_parameters\n\n# if you want to train from scratch, you can remove the --lora_checkpoint argument\n"
  },
  {
    "path": "examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control-V2.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"qwen_image/Qwen-Image-Layered-Control-V2/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/qwen_image/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-Layered-Control-V2 \\\n  --dataset_metadata_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-Layered-Control-V2/metadata.json \\\n  --data_file_keys \"image,layer_input_image,context_image\" \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 50 \\\n  --model_id_with_origin_paths \"DiffSynth-Studio/Qwen-Image-Layered-Control:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image-Layered:vae/diffusion_pytorch_model.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Qwen-Image-Layered-Control-V2_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1\" \\\n  --lora_rank 64 \\\n  --extra_inputs \"layer_num,layer_input_image,context_image\" \\\n  --use_gradient_checkpointing \\\n  --dataset_num_workers 8 \\\n  --find_unused_parameters\n"
  },
  {
    "path": "examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"qwen_image/Qwen-Image-Layered-Control/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/qwen_image/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-Layered-Control \\\n  --dataset_metadata_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-Layered-Control/metadata.json \\\n  --data_file_keys \"image,layer_input_image\" \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 50 \\\n  --model_id_with_origin_paths \"DiffSynth-Studio/Qwen-Image-Layered-Control:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image-Layered:vae/diffusion_pytorch_model.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Qwen-Image-Layered-Control_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1\" \\\n  --lora_rank 32 \\\n  --extra_inputs \"layer_num,layer_input_image\" \\\n  --use_gradient_checkpointing \\\n  --dataset_num_workers 8 \\\n  --find_unused_parameters\n"
  },
  {
    "path": "examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"qwen_image/Qwen-Image-Layered/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/qwen_image/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-Layered \\\n  --dataset_metadata_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-Layered/metadata.json \\\n  --data_file_keys \"image,layer_input_image\" \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 50 \\\n  --model_id_with_origin_paths \"Qwen/Qwen-Image-Layered:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image-Layered:vae/diffusion_pytorch_model.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Qwen-Image-Layered_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1\" \\\n  --lora_rank 32 \\\n  --extra_inputs \"layer_num,layer_input_image\" \\\n  --use_gradient_checkpointing \\\n  --dataset_num_workers 8 \\\n  --find_unused_parameters\n"
  },
  {
    "path": "examples/qwen_image/model_training/lora/Qwen-Image.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"qwen_image/Qwen-Image/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/qwen_image/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/qwen_image/Qwen-Image \\\n  --dataset_metadata_path data/diffsynth_example_dataset/qwen_image/Qwen-Image/metadata.csv \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 50 \\\n  --model_id_with_origin_paths \"Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Qwen-Image_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1\" \\\n  --lora_rank 32 \\\n  --use_gradient_checkpointing \\\n  --dataset_num_workers 8 \\\n  --find_unused_parameters\n"
  },
  {
    "path": "examples/qwen_image/model_training/scripts/Qwen-Image-Blockwise-ControlNet-Initialize.py",
    "content": "# This script is for initializing a Qwen-Image-Blockwise-ControlNet\nfrom diffsynth import hash_state_dict_keys\nfrom diffsynth.models.qwen_image_controlnet import QwenImageBlockWiseControlNet\nimport torch\nfrom safetensors.torch import save_file\n\n\ncontrolnet = QwenImageBlockWiseControlNet().to(dtype=torch.bfloat16, device=\"cuda\")\ncontrolnet.init_weight()\nstate_dict_controlnet = controlnet.state_dict()\n\nprint(hash_state_dict_keys(state_dict_controlnet))\nsave_file(state_dict_controlnet, \"models/blockwise_controlnet.safetensors\")\n"
  },
  {
    "path": "examples/qwen_image/model_training/scripts/Qwen-Image-Blockwise-ControlNet-Inpaint-Initialize.py",
    "content": "# This script is for initializing an inpainting Qwen-Image-Blockwise-ControlNet\nimport torch\nfrom diffsynth import hash_state_dict_keys\nfrom diffsynth.models.qwen_image_controlnet import QwenImageBlockWiseControlNet\nfrom safetensors.torch import save_file\n\ncontrolnet = QwenImageBlockWiseControlNet(additional_in_dim=4).to(dtype=torch.bfloat16, device=\"cuda\")\ncontrolnet.init_weight()\nstate_dict_controlnet = controlnet.state_dict()\n\nprint(hash_state_dict_keys(state_dict_controlnet))\nsave_file(state_dict_controlnet, \"models/blockwise_controlnet_inpaint.safetensors\")\n"
  },
  {
    "path": "examples/qwen_image/model_training/special/differential_training/Qwen-Image-LoRA.sh",
    "content": "# This script is provided as an example only.\n# Please manually replace the two datasets:\n# the first training dataset should contain content you do not want to generate,\n# and the second training dataset should contain content you do want to generate.\n\nmodelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"qwen_image/Qwen-Image/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/qwen_image/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/qwen_image/Qwen-Image \\\n  --dataset_metadata_path data/diffsynth_example_dataset/qwen_image/Qwen-Image/metadata.csv \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 50 \\\n  --model_id_with_origin_paths \"Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Qwen-Image-LoRA-deterministic\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1\" \\\n  --lora_rank 32 \\\n  --use_gradient_checkpointing \\\n  --dataset_num_workers 8 \\\n  --find_unused_parameters\n\naccelerate launch examples/qwen_image/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/qwen_image/Qwen-Image \\\n  --dataset_metadata_path data/diffsynth_example_dataset/qwen_image/Qwen-Image/metadata.csv \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 50 \\\n  --model_id_with_origin_paths \"Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  
--output_path \"./models/train/Qwen-Image-LoRA-differencial\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1\" \\\n  --lora_rank 32 \\\n  --use_gradient_checkpointing \\\n  --dataset_num_workers 8 \\\n  --find_unused_parameters \\\n  --preset_lora_path \"./models/train/Qwen-Image-LoRA-deterministic/epoch-4.safetensors\" \\\n  --preset_lora_model \"dit\"\n"
  },
  {
    "path": "examples/qwen_image/model_training/special/fp8_training/Qwen-Image-LoRA.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"qwen_image/Qwen-Image/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/qwen_image/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/qwen_image/Qwen-Image \\\n  --dataset_metadata_path data/diffsynth_example_dataset/qwen_image/Qwen-Image/metadata.csv \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 50 \\\n  --model_id_with_origin_paths \"Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Qwen-Image_lora_fp8\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1\" \\\n  --lora_rank 32 \\\n  --use_gradient_checkpointing \\\n  --dataset_num_workers 8 \\\n  --find_unused_parameters \\\n  --fp8_models \"Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors\"\n"
  },
  {
    "path": "examples/qwen_image/model_training/special/fp8_training/validate.py",
    "content": "from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nimport torch\n\n\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n)\npipe.load_lora(pipe.dit, \"models/train/Qwen-Image_lora_fp8/epoch-4.safetensors\")\nprompt = \"a dog\"\nimage = pipe(prompt, seed=0)\nimage.save(\"image.jpg\")\n"
  },
  {
    "path": "examples/qwen_image/model_training/special/low_vram_training/Qwen-Image-LoRA.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"qwen_image/Qwen-Image/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/qwen_image/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/qwen_image/Qwen-Image \\\n  --dataset_metadata_path data/diffsynth_example_dataset/qwen_image/Qwen-Image/metadata.csv \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 1 \\\n  --model_id_with_origin_paths \"Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Qwen-Image_lora-splited-cache\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1\" \\\n  --lora_rank 32 \\\n  --task \"sft:data_process\" \\\n  --use_gradient_checkpointing \\\n  --dataset_num_workers 8 \\\n  --find_unused_parameters\n\naccelerate launch --config_file examples/qwen_image/model_training/special/low_vram_training/deepspeed_zero3_cpuoffload.yaml examples/qwen_image/model_training/train.py \\\n  --dataset_base_path \"./models/train/Qwen-Image_lora-splited-cache\" \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 50 \\\n  --model_id_with_origin_paths \"Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Qwen-Image_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1\" \\\n  --lora_rank 32 \\\n  --task \"sft:train\" \\\n  --use_gradient_checkpointing \\\n  --dataset_num_workers 8 \\\n  --find_unused_parameters \\\n  
--initialize_model_on_cpu\n"
  },
  {
    "path": "examples/qwen_image/model_training/special/low_vram_training/deepspeed_zero3_cpuoffload.yaml",
    "content": "compute_environment: LOCAL_MACHINE\ndebug: true\ndeepspeed_config:\n  deepspeed_config_file: examples/qwen_image/model_training/special/low_vram_training/ds_z3_cpuoffload.json\n  zero3_init_flag: true\ndistributed_type: DEEPSPEED\ndowncast_bf16: 'no'\nenable_cpu_affinity: false\nmachine_rank: 0\nmain_training_function: main\nnum_machines: 1\nnum_processes: 1\nrdzv_backend: static\nsame_network: true\ntpu_env: []\ntpu_use_cluster: false\ntpu_use_sudo: false\nuse_cpu: false\n"
  },
  {
    "path": "examples/qwen_image/model_training/special/low_vram_training/ds_z3_cpuoffload.json",
    "content": "{\n    \"fp16\": {\n        \"enabled\": \"auto\",\n        \"loss_scale\": 0,\n        \"loss_scale_window\": 1000,\n        \"initial_scale_power\": 16,\n        \"hysteresis\": 2,\n        \"min_loss_scale\": 1\n    },\n    \"bf16\": {\n        \"enabled\": \"auto\"\n    },\n    \"zero_optimization\": {\n        \"stage\": 3,\n        \"offload_optimizer\": {\n            \"device\": \"cpu\",\n            \"pin_memory\": true\n        },\n        \"offload_param\": {\n            \"device\": \"cpu\",\n            \"pin_memory\": true\n        },\n        \"overlap_comm\": false,\n        \"contiguous_gradients\": true,\n        \"sub_group_size\": 1e9,\n        \"reduce_bucket_size\": 5e7,\n        \"stage3_prefetch_bucket_size\": 5e7,\n        \"stage3_param_persistence_threshold\": 1e5,\n        \"stage3_max_live_parameters\": 1e8,\n        \"stage3_max_reuse_distance\": 1e8,\n        \"stage3_gather_16bit_weights_on_model_save\": true\n    },\n    \"activation_checkpointing\": {\n        \"partition_activations\": false,\n        \"cpu_checkpointing\": false,\n        \"contiguous_memory_optimization\": false\n    },\n    \"gradient_accumulation_steps\": \"auto\",\n    \"gradient_clipping\": \"auto\",\n    \"train_batch_size\": \"auto\",\n    \"train_micro_batch_size_per_gpu\": \"auto\",\n    \"wall_clock_breakdown\": false\n}"
  },
  {
    "path": "examples/qwen_image/model_training/special/npu_training/Qwen-Image-Edit-2509-LoRA-NPU.sh",
    "content": "# Due to memory limitations, split training is required to train the model on NPU\nexport PYTORCH_NPU_ALLOC_CONF=expandable_segments:True\nexport CPU_AFFINITY_CONF=1\n\nmodelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"qwen_image/Qwen-Image-Edit-2509/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/qwen_image/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-Edit-2509 \\\n  --dataset_metadata_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-Edit-2509/metadata.json \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 1 \\\n  --model_id_with_origin_paths \"Qwen/Qwen-Image-Edit-2509:text_encoder/model*.safetensors,Qwen/Qwen-Image-Edit-2509:vae/diffusion_pytorch_model.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Qwen-Image-Edit-2509-LoRA-splited-cache\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1\" \\\n  --lora_rank 32 \\\n  --use_gradient_checkpointing \\\n  --dataset_num_workers 8 \\\n  --find_unused_parameters \\\n  --task \"sft:data_process\"\n\naccelerate launch examples/qwen_image/model_training/train.py \\\n  --dataset_base_path \"./models/train/Qwen-Image-Edit-2509-LoRA-splited-cache\" \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 50 \\\n  --model_id_with_origin_paths \"Qwen/Qwen-Image-Edit-2509:transformer/diffusion_pytorch_model*.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Qwen-Image-Edit-2509-LoRA-splited\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1\" \\\n  
--lora_rank 32 \\\n  --use_gradient_checkpointing \\\n  --dataset_num_workers 8 \\\n  --find_unused_parameters \\\n  --task \"sft:train\"\n"
  },
  {
    "path": "examples/qwen_image/model_training/special/npu_training/Qwen-Image-Edit-2509-NPU.sh",
    "content": "# This script was tested using zero3 and on 8*910B(NPU)\nexport PYTORCH_NPU_ALLOC_CONF=expandable_segments:True\nexport CPU_AFFINITY_CONF=1\n\nmodelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"qwen_image/Qwen-Image-Edit-2509/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch --config_file examples/qwen_image/model_training/full/accelerate_config_zero3.yaml examples/qwen_image/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-Edit-2509 \\\n  --dataset_metadata_path data/diffsynth_example_dataset/qwen_image/Qwen-Image-Edit-2509/metadata.json \\\n  --data_file_keys \"image,edit_image\" \\\n  --extra_inputs \"edit_image\" \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 50 \\\n  --model_id_with_origin_paths \"Qwen/Qwen-Image-Edit-2509:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 2 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Qwen-Image-Edit-2509_full\" \\\n  --trainable_models \"dit\" \\\n  --use_gradient_checkpointing \\\n  --find_unused_parameters \\\n  --initialize_model_on_cpu\n"
  },
  {
    "path": "examples/qwen_image/model_training/special/npu_training/Qwen-Image-LoRA-NPU.sh",
    "content": "# Due to memory limitations, split training is required to train the model on NPU\nexport PYTORCH_NPU_ALLOC_CONF=expandable_segments:True\nexport CPU_AFFINITY_CONF=1\n\nmodelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"qwen_image/Qwen-Image/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/qwen_image/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/qwen_image/Qwen-Image \\\n  --dataset_metadata_path data/diffsynth_example_dataset/qwen_image/Qwen-Image/metadata.csv \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 1 \\\n  --model_id_with_origin_paths \"Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Qwen-Image-LoRA-splited-cache\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1\" \\\n  --lora_rank 32 \\\n  --use_gradient_checkpointing \\\n  --dataset_num_workers 8 \\\n  --find_unused_parameters \\\n  --task \"sft:data_process\"\n\naccelerate launch examples/qwen_image/model_training/train.py \\\n  --dataset_base_path \"./models/train/Qwen-Image-LoRA-splited-cache\" \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 50 \\\n  --model_id_with_origin_paths \"Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Qwen-Image-LoRA-splited\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1\" \\\n  --lora_rank 32 \\\n  --use_gradient_checkpointing \\\n  --dataset_num_workers 8 \\\n  
--find_unused_parameters \\\n  --task \"sft:train\"\n"
  },
  {
    "path": "examples/qwen_image/model_training/special/simple/train.py",
    "content": "import torch, accelerate\nfrom diffsynth.core import UnifiedDataset\nfrom diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nfrom diffsynth.diffusion import *\n\nclass QwenImageTrainingModule(DiffusionTrainingModule):\n    def __init__(self, device):\n        super().__init__()\n        # Load the pipeline\n        self.pipe = QwenImagePipeline.from_pretrained(\n            torch_dtype=torch.bfloat16,\n            device=device,\n            model_configs=[\n                ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n                ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n                ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n            ],\n            tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n        )\n        # Switch to training mode\n        self.switch_pipe_to_training_mode(\n            self.pipe,\n            lora_base_model=\"dit\",\n            lora_target_modules=\"to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj\",\n            lora_rank=32,\n        )\n\n    def forward(self, data):\n        # Preprocess\n        inputs_posi = {\"prompt\": data[\"prompt\"]}\n        inputs_nega = {\"negative_prompt\": \"\"}\n        inputs_shared = {\n            # Assume you are using this pipeline for inference,\n            # please fill in the input parameters.\n            \"input_image\": data[\"image\"],\n            \"height\": data[\"image\"].size[1],\n            \"width\": data[\"image\"].size[0],\n            # Please do not modify the following parameters\n            # unless you clearly know what this will cause.\n            \"cfg_scale\": 1,\n            \"rand_device\": self.pipe.device,\n            \"use_gradient_checkpointing\": True,\n            
\"use_gradient_checkpointing_offload\": False,\n        }\n        for unit in self.pipe.units:\n            inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega)\n        # Loss\n        loss = FlowMatchSFTLoss(self.pipe, **inputs_shared, **inputs_posi)\n        return loss\n\nif __name__ == \"__main__\":\n    accelerator = accelerate.Accelerator(\n        kwargs_handlers=[accelerate.DistributedDataParallelKwargs(find_unused_parameters=True)],\n    )\n    dataset = UnifiedDataset(\n        base_path=\"data/example_image_dataset\",\n        metadata_path=\"data/example_image_dataset/metadata.csv\",\n        repeat=50,\n        data_file_keys=\"image\",\n        main_data_operator=UnifiedDataset.default_image_operator(\n            base_path=\"data/example_image_dataset\",\n            height=512,\n            width=512,\n            height_division_factor=16,\n            width_division_factor=16,\n        )\n    )\n    model = QwenImageTrainingModule(accelerator.device)\n    model_logger = ModelLogger(\n        output_path=\"models/toy_model\",\n        remove_prefix_in_ckpt=\"pipe.dit.\",\n    )\n    launch_training_task(\n        accelerator, dataset, model, model_logger,\n        learning_rate=1e-5, num_epochs=1,\n    )\n"
  },
  {
    "path": "examples/qwen_image/model_training/special/split_training/Qwen-Image-LoRA.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"qwen_image/Qwen-Image/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/qwen_image/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/qwen_image/Qwen-Image \\\n  --dataset_metadata_path data/diffsynth_example_dataset/qwen_image/Qwen-Image/metadata.csv \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 1 \\\n  --model_id_with_origin_paths \"Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Qwen-Image-LoRA-splited-cache\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1\" \\\n  --lora_rank 32 \\\n  --use_gradient_checkpointing \\\n  --dataset_num_workers 8 \\\n  --find_unused_parameters \\\n  --task \"sft:data_process\"\n\naccelerate launch examples/qwen_image/model_training/train.py \\\n  --dataset_base_path \"./models/train/Qwen-Image-LoRA-splited-cache\" \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 50 \\\n  --model_id_with_origin_paths \"Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Qwen-Image-LoRA-splited\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1\" \\\n  --lora_rank 32 \\\n  --use_gradient_checkpointing \\\n  --dataset_num_workers 8 \\\n  --find_unused_parameters \\\n  --task \"sft:train\"\n"
  },
  {
    "path": "examples/qwen_image/model_training/special/split_training/validate.py",
    "content": "from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nimport torch\n\n\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n)\npipe.load_lora(pipe.dit, \"models/train/Qwen-Image-LoRA-splited/epoch-4.safetensors\")\nprompt = \"a dog\"\nimage = pipe(prompt, seed=0)\nimage.save(\"image.jpg\")\n"
  },
  {
    "path": "examples/qwen_image/model_training/train.py",
    "content": "import torch, os, argparse, accelerate\nfrom diffsynth.core import UnifiedDataset\nfrom diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nfrom diffsynth.diffusion import *\nfrom diffsynth.core.data.operators import *\nos.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n\n\nclass QwenImageTrainingModule(DiffusionTrainingModule):\n    def __init__(\n        self,\n        model_paths=None, model_id_with_origin_paths=None,\n        tokenizer_path=None, processor_path=None,\n        trainable_models=None,\n        lora_base_model=None, lora_target_modules=\"\", lora_rank=32, lora_checkpoint=None,\n        preset_lora_path=None, preset_lora_model=None,\n        use_gradient_checkpointing=True,\n        use_gradient_checkpointing_offload=False,\n        extra_inputs=None,\n        fp8_models=None,\n        offload_models=None,\n        device=\"cpu\",\n        task=\"sft\",\n        zero_cond_t=False,\n    ):\n        super().__init__()\n        # Load models\n        model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device)\n        tokenizer_config = ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\") if tokenizer_path is None else ModelConfig(tokenizer_path)\n        processor_config = ModelConfig(model_id=\"Qwen/Qwen-Image-Edit\", origin_file_pattern=\"processor/\") if processor_path is None else ModelConfig(processor_path)\n        self.pipe = QwenImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config, processor_config=processor_config)\n        self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model)\n\n        # Training mode\n        self.switch_pipe_to_training_mode(\n            self.pipe, trainable_models,\n            lora_base_model, lora_target_modules, lora_rank, lora_checkpoint,\n            
preset_lora_path, preset_lora_model,\n            task=task,\n        )\n        \n        # Other configs\n        self.use_gradient_checkpointing = use_gradient_checkpointing\n        self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload\n        self.extra_inputs = extra_inputs.split(\",\") if extra_inputs is not None else []\n        self.fp8_models = fp8_models\n        self.task = task\n        self.zero_cond_t = zero_cond_t\n        self.task_to_loss = {\n            \"sft:data_process\": lambda pipe, *args: args,\n            \"direct_distill:data_process\": lambda pipe, *args: args,\n            \"sft\": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),\n            \"sft:train\": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),\n            \"direct_distill\": lambda pipe, inputs_shared, inputs_posi, inputs_nega: DirectDistillLoss(pipe, **inputs_shared, **inputs_posi),\n            \"direct_distill:train\": lambda pipe, inputs_shared, inputs_posi, inputs_nega: DirectDistillLoss(pipe, **inputs_shared, **inputs_posi),\n        }\n        \n    def get_pipeline_inputs(self, data):\n        inputs_posi = {\"prompt\": data[\"prompt\"]}\n        inputs_nega = {\"negative_prompt\": \"\"}\n        inputs_shared = {\n            # Please do not modify the following parameters\n            # unless you clearly know what this will cause.\n            \"cfg_scale\": 1,\n            \"rand_device\": self.pipe.device,\n            \"use_gradient_checkpointing\": self.use_gradient_checkpointing,\n            \"use_gradient_checkpointing_offload\": self.use_gradient_checkpointing_offload,\n            \"edit_image_auto_resize\": True,\n            \"zero_cond_t\": self.zero_cond_t,\n        }\n        # Assume you are using this pipeline for inference,\n        # please fill in the input parameters.\n        if 
isinstance(data[\"image\"], list):\n            inputs_shared.update({\n                \"input_image\": data[\"image\"],\n                \"height\": data[\"image\"][0].size[1],\n                \"width\": data[\"image\"][0].size[0],\n            })\n        else:\n            inputs_shared.update({\n                \"input_image\": data[\"image\"],\n                \"height\": data[\"image\"].size[1],\n                \"width\": data[\"image\"].size[0],\n            })\n        inputs_shared = self.parse_extra_inputs(data, self.extra_inputs, inputs_shared)\n        return inputs_shared, inputs_posi, inputs_nega\n    \n    def forward(self, data, inputs=None):\n        if inputs is None: inputs = self.get_pipeline_inputs(data)\n        inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype)\n        for unit in self.pipe.units:\n            inputs = self.pipe.unit_runner(unit, self.pipe, *inputs)\n        loss = self.task_to_loss[self.task](self.pipe, *inputs)\n        return loss\n\n\ndef qwen_image_parser():\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser = add_general_config(parser)\n    parser = add_image_size_config(parser)\n    parser.add_argument(\"--tokenizer_path\", type=str, default=None, help=\"Path to tokenizer.\")\n    parser.add_argument(\"--processor_path\", type=str, default=None, help=\"Path to the processor. If provided, the processor will be used for image editing.\")\n    parser.add_argument(\"--zero_cond_t\", default=False, action=\"store_true\", help=\"A special parameter introduced by Qwen-Image-Edit-2511. 
Please enable it for this model.\")\n    parser.add_argument(\"--initialize_model_on_cpu\", default=False, action=\"store_true\", help=\"Whether to initialize models on CPU.\")\n    return parser\n\n\nif __name__ == \"__main__\":\n    parser = qwen_image_parser()\n    args = parser.parse_args()\n    accelerator = accelerate.Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        kwargs_handlers=[accelerate.DistributedDataParallelKwargs(find_unused_parameters=args.find_unused_parameters)],\n    )\n    dataset = UnifiedDataset(\n        base_path=args.dataset_base_path,\n        metadata_path=args.dataset_metadata_path,\n        repeat=args.dataset_repeat,\n        data_file_keys=args.data_file_keys.split(\",\"),\n        main_data_operator=UnifiedDataset.default_image_operator(\n            base_path=args.dataset_base_path,\n            max_pixels=args.max_pixels,\n            height=args.height,\n            width=args.width,\n            height_division_factor=16,\n            width_division_factor=16,\n        ),\n        special_operator_map={\n            # Qwen-Image-Layered\n            \"layer_input_image\": ToAbsolutePath(args.dataset_base_path) >> LoadImage(convert_RGB=False, convert_RGBA=True) >> ImageCropAndResize(args.height, args.width, args.max_pixels, 16, 16),\n            \"image\": RouteByType(operator_map=[\n                (str, ToAbsolutePath(args.dataset_base_path) >> LoadImage() >> ImageCropAndResize(args.height, args.width, args.max_pixels, 16, 16)),\n                (list, SequencialProcess(ToAbsolutePath(args.dataset_base_path) >> LoadImage(convert_RGB=False, convert_RGBA=True) >> ImageCropAndResize(args.height, args.width, args.max_pixels, 16, 16))),\n            ]),\n            \"context_image\": RouteByType(operator_map=[\n                (str, ToAbsolutePath(args.dataset_base_path) >> LoadImage() >> ImageCropAndResize(args.height, args.width, args.max_pixels, 16, 16)),\n                (None, lambda 
x: None),\n            ])\n        }\n    )\n    model = QwenImageTrainingModule(\n        model_paths=args.model_paths,\n        model_id_with_origin_paths=args.model_id_with_origin_paths,\n        tokenizer_path=args.tokenizer_path,\n        processor_path=args.processor_path,\n        trainable_models=args.trainable_models,\n        lora_base_model=args.lora_base_model,\n        lora_target_modules=args.lora_target_modules,\n        lora_rank=args.lora_rank,\n        lora_checkpoint=args.lora_checkpoint,\n        preset_lora_path=args.preset_lora_path,\n        preset_lora_model=args.preset_lora_model,\n        use_gradient_checkpointing=args.use_gradient_checkpointing,\n        use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,\n        extra_inputs=args.extra_inputs,\n        fp8_models=args.fp8_models,\n        offload_models=args.offload_models,\n        task=args.task,\n        device=\"cpu\" if args.initialize_model_on_cpu else accelerator.device,\n        zero_cond_t=args.zero_cond_t,\n    )\n    model_logger = ModelLogger(\n        args.output_path,\n        remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,\n    )\n    launcher_map = {\n        \"sft:data_process\": launch_data_process_task,\n        \"direct_distill:data_process\": launch_data_process_task,\n        \"sft\": launch_training_task,\n        \"sft:train\": launch_training_task,\n        \"direct_distill\": launch_training_task,\n        \"direct_distill:train\": launch_training_task,\n    }\n    launcher_map[args.task](accelerator, dataset, model, model_logger, args=args)\n"
  },
  {
    "path": "examples/qwen_image/model_training/validate_full/FireRed-Image-Edit-1.0.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nfrom diffsynth import load_state_dict\n\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"FireRedTeam/FireRed-Image-Edit-1.0\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=None,\n    processor_config=ModelConfig(model_id=\"Qwen/Qwen-Image-Edit\", origin_file_pattern=\"processor/\"),\n)\nstate_dict = load_state_dict(\"models/train/FireRed-Image-Edit-1.0_full/epoch-1.safetensors\")\npipe.dit.load_state_dict(state_dict)\n\nprompt = \"Change the color of the dress in Figure 1 to the color shown in Figure 2.\"\nimages = [\n    Image.open(\"data/example_image_dataset/edit/image1.jpg\").resize((1024, 1024)),\n    Image.open(\"data/example_image_dataset/edit/image_color.jpg\").resize((1024, 1024)),\n]\nimage = pipe(prompt, edit_image=images, seed=123, num_inference_steps=40, height=1024, width=1024)\nimage.save(\"image.jpg\")\n"
  },
  {
    "path": "examples/qwen_image/model_training/validate_full/FireRed-Image-Edit-1.1.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nfrom diffsynth import load_state_dict\n\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"FireRedTeam/FireRed-Image-Edit-1.1\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=None,\n    processor_config=ModelConfig(model_id=\"Qwen/Qwen-Image-Edit\", origin_file_pattern=\"processor/\"),\n)\nstate_dict = load_state_dict(\"models/train/FireRed-Image-Edit-1.1_full/epoch-1.safetensors\")\npipe.dit.load_state_dict(state_dict)\n\nprompt = \"Change the color of the dress in Figure 1 to the color shown in Figure 2.\"\nimages = [\n    Image.open(\"data/example_image_dataset/edit/image1.jpg\").resize((1024, 1024)),\n    Image.open(\"data/example_image_dataset/edit/image_color.jpg\").resize((1024, 1024)),\n]\nimage = pipe(prompt, edit_image=images, seed=123, num_inference_steps=40, height=1024, width=1024)\nimage.save(\"image.jpg\")\n"
  },
  {
    "path": "examples/qwen_image/model_training/validate_full/Qwen-Image-2512.py",
    "content": "from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nfrom diffsynth import load_state_dict\nimport torch\n\n\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image-2512\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n)\nstate_dict = load_state_dict(\"models/train/Qwen-Image-2512_full/epoch-1.safetensors\")\npipe.dit.load_state_dict(state_dict)\nprompt = \"a dog\"\nimage = pipe(prompt, seed=0)\nimage.save(\"image.jpg\")\n"
  },
  {
    "path": "examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Canny.py",
    "content": "from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput\nfrom PIL import Image\nimport torch\nfrom modelscope import dataset_snapshot_download\n\n\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n        ModelConfig(path=\"models/train/Qwen-Image-Blockwise-ControlNet-Canny_full/epoch-1.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n)\n\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/example_image_dataset\",\n    local_dir=\"./data/example_image_dataset\",\n    allow_file_pattern=\"canny/image_1.jpg\"\n)\ncontrolnet_image = Image.open(\"data/example_image_dataset/canny/image_1.jpg\").resize((1328, 1328))\n\nprompt = \"一只小狗，毛发光洁柔顺，眼神灵动，背景是樱花纷飞的春日庭院，唯美温馨。\"\nimage = pipe(\n    prompt, seed=0,\n    blockwise_controlnet_inputs=[ControlNetInput(image=controlnet_image)]\n)\nimage.save(\"image.jpg\")\n"
  },
  {
    "path": "examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Depth.py",
    "content": "from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput\nfrom PIL import Image\nimport torch\nfrom modelscope import dataset_snapshot_download\n\n\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n        ModelConfig(path=\"models/train/Qwen-Image-Blockwise-ControlNet-Depth_full/epoch-1.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n)\n\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/example_image_dataset\",\n    local_dir=\"./data/example_image_dataset\",\n    allow_file_pattern=\"depth/image_1.jpg\"\n)\ncontrolnet_image = Image.open(\"data/example_image_dataset/depth/image_1.jpg\").resize((1328, 1328))\n\nprompt = \"精致肖像，水下少女，蓝裙飘逸，发丝轻扬，光影透澈，气泡环绕，面容恬静，细节精致，梦幻唯美。\"\nimage = pipe(\n    prompt, seed=0,\n    blockwise_controlnet_inputs=[ControlNetInput(image=controlnet_image)]\n)\nimage.save(\"image.jpg\")\n"
  },
  {
    "path": "examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py",
    "content": "import torch\nfrom PIL import Image\nfrom modelscope import dataset_snapshot_download\nfrom diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput\n\n\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n        ModelConfig(path=\"models/train/Qwen-Image-Blockwise-ControlNet-Inpaint_full/epoch-1.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n)\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/example_image_dataset\",\n    local_dir=\"./data/example_image_dataset\",\n    allow_file_pattern=\"inpaint/*.jpg\"\n)\nprompt = \"a cat with sunglasses\"\ncontrolnet_image = Image.open(\"./data/example_image_dataset/inpaint/image_1.jpg\").convert(\"RGB\").resize((1024, 1024))\ninpaint_mask = Image.open(\"./data/example_image_dataset/inpaint/mask.jpg\").convert(\"RGB\").resize((1024, 1024))\nimage = pipe(\n    prompt, seed=0,\n    blockwise_controlnet_inputs=[ControlNetInput(image=controlnet_image, inpaint_mask=inpaint_mask)],\n    height=1024, width=1024,\n    num_inference_steps=40,\n)\nimage.save(\"image.jpg\")\n"
  },
  {
    "path": "examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py",
    "content": "from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nfrom diffsynth import load_state_dict\nimport torch\n\n\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"DiffSynth-Studio/Qwen-Image-Distill-Full\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n)\nstate_dict = load_state_dict(\"models/train/Qwen-Image-Distill-Full_full/epoch-1.safetensors\")\npipe.dit.load_state_dict(state_dict)\nprompt = \"a dog\"\nimage = pipe(prompt, seed=0, num_inference_steps=15, cfg_scale=1)\nimage.save(\"image.jpg\")\n"
  },
  {
    "path": "examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nfrom diffsynth import load_state_dict\n\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image-Edit-2509\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=None,\n    processor_config=ModelConfig(model_id=\"Qwen/Qwen-Image-Edit\", origin_file_pattern=\"processor/\"),\n)\nstate_dict = load_state_dict(\"models/train/Qwen-Image-Edit-2509_full/epoch-1.safetensors\")\npipe.dit.load_state_dict(state_dict)\n\nprompt = \"Change the color of the dress in Figure 1 to the color shown in Figure 2.\"\nimages = [\n    Image.open(\"data/example_image_dataset/edit/image1.jpg\").resize((1024, 1024)),\n    Image.open(\"data/example_image_dataset/edit/image_color.jpg\").resize((1024, 1024)),\n]\nimage = pipe(prompt, edit_image=images, seed=123, num_inference_steps=40, height=1024, width=1024)\nimage.save(\"image.jpg\")\n"
  },
  {
    "path": "examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nfrom diffsynth import load_state_dict\n\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image-Edit-2511\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=None,\n    processor_config=ModelConfig(model_id=\"Qwen/Qwen-Image-Edit\", origin_file_pattern=\"processor/\"),\n)\nstate_dict = load_state_dict(\"models/train/Qwen-Image-Edit-2511_full/epoch-1.safetensors\")\npipe.dit.load_state_dict(state_dict)\n\nprompt = \"Change the color of the dress in Figure 1 to the color shown in Figure 2.\"\nimages = [\n    Image.open(\"data/example_image_dataset/edit/image1.jpg\").resize((1024, 1024)),\n    Image.open(\"data/example_image_dataset/edit/image_color.jpg\").resize((1024, 1024)),\n]\nimage = pipe(prompt, edit_image=images, seed=123, num_inference_steps=40, height=1024, width=1024, zero_cond_t=True)\nimage.save(\"image.jpg\")\n"
  },
  {
    "path": "examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nfrom diffsynth import load_state_dict\n\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image-Edit\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=None,\n    processor_config=ModelConfig(model_id=\"Qwen/Qwen-Image-Edit\", origin_file_pattern=\"processor/\"),\n)\nstate_dict = load_state_dict(\"models/train/Qwen-Image-Edit_full/epoch-1.safetensors\")\npipe.dit.load_state_dict(state_dict)\n\nprompt = \"将裙子改为粉色\"\nimage = Image.open(\"data/example_image_dataset/edit/image1.jpg\").resize((1024, 1024))\nimage = pipe(prompt, edit_image=image, seed=0, num_inference_steps=40, height=1024, width=1024)\nimage.save(f\"image.jpg\")\n"
  },
  {
    "path": "examples/qwen_image/model_training/validate_full/Qwen-Image-Layered-Control.py",
    "content": "from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nfrom diffsynth import load_state_dict\nfrom PIL import Image\nimport torch\n\n\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"DiffSynth-Studio/Qwen-Image-Layered-Control\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image-Layered\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n)\nstate_dict = load_state_dict(\"models/train/Qwen-Image-Layered-Control_full/epoch-1.safetensors\")\npipe.dit.load_state_dict(state_dict)\nprompt = \"Text 'HELLO' and 'Have a great day'\"\ninput_image = Image.open(\"data/example_image_dataset/layer/image.png\").convert(\"RGBA\").resize((864, 480))\nimages = pipe(\n    prompt, seed=0,\n    height=480, width=864,\n    layer_input_image=input_image, layer_num=0,\n)\nimages[0].save(\"image.png\")\n"
  },
  {
    "path": "examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py",
    "content": "from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nfrom diffsynth import load_state_dict\nfrom PIL import Image\nimport torch\n\n\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image-Layered\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image-Layered\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n)\nstate_dict = load_state_dict(\"models/train/Qwen-Image-Layered_full/epoch-1.safetensors\")\npipe.dit.load_state_dict(state_dict)\nprompt = \"a poster\"\ninput_image = Image.open(\"data/example_image_dataset/layer/image.png\").convert(\"RGBA\").resize((864, 480))\nimages = pipe(\n    prompt, seed=0,\n    height=480, width=864,\n    layer_input_image=input_image, layer_num=3,\n)\nfor i, image in enumerate(images):\n    if i == 0: continue # The first image is the input image.\n    image.save(f\"image_{i}.png\")\n"
  },
  {
    "path": "examples/qwen_image/model_training/validate_full/Qwen-Image.py",
    "content": "from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nfrom diffsynth import load_state_dict\nimport torch\n\n\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n)\nstate_dict = load_state_dict(\"models/train/Qwen-Image_full/epoch-1.safetensors\")\npipe.dit.load_state_dict(state_dict)\nprompt = \"a dog\"\nimage = pipe(prompt, seed=0)\nimage.save(\"image.jpg\")\n"
  },
  {
    "path": "examples/qwen_image/model_training/validate_lora/FireRed-Image-Edit-1.0.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\n\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"FireRedTeam/FireRed-Image-Edit-1.0\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=None,\n    processor_config=ModelConfig(model_id=\"Qwen/Qwen-Image-Edit\", origin_file_pattern=\"processor/\"),\n)\npipe.load_lora(pipe.dit, \"models/train/FireRed-Image-Edit-1.0_lora/epoch-4.safetensors\")\n\nprompt = \"Change the color of the dress in Figure 1 to the color shown in Figure 2.\"\nimages = [\n    Image.open(\"data/example_image_dataset/edit/image1.jpg\").resize((1024, 1024)),\n    Image.open(\"data/example_image_dataset/edit/image_color.jpg\").resize((1024, 1024)),\n]\nimage = pipe(prompt, edit_image=images, seed=123, num_inference_steps=40, height=1024, width=1024)\nimage.save(\"image.jpg\")\n"
  },
  {
    "path": "examples/qwen_image/model_training/validate_lora/FireRed-Image-Edit-1.1.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\n\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"FireRedTeam/FireRed-Image-Edit-1.1\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=None,\n    processor_config=ModelConfig(model_id=\"Qwen/Qwen-Image-Edit\", origin_file_pattern=\"processor/\"),\n)\npipe.load_lora(pipe.dit, \"models/train/FireRed-Image-Edit-1.1_lora/epoch-4.safetensors\")\n\nprompt = \"Change the color of the dress in Figure 1 to the color shown in Figure 2.\"\nimages = [\n    Image.open(\"data/example_image_dataset/edit/image1.jpg\").resize((1024, 1024)),\n    Image.open(\"data/example_image_dataset/edit/image_color.jpg\").resize((1024, 1024)),\n]\nimage = pipe(prompt, edit_image=images, seed=123, num_inference_steps=40, height=1024, width=1024)\nimage.save(\"image.jpg\")\n"
  },
  {
    "path": "examples/qwen_image/model_training/validate_lora/Qwen-Image-2512.py",
    "content": "from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nimport torch\n\n\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image-2512\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n)\npipe.load_lora(pipe.dit, \"models/train/Qwen-Image-2512_lora/epoch-4.safetensors\")\nprompt = \"a dog\"\nimage = pipe(prompt, seed=0)\nimage.save(\"image.jpg\")\n"
  },
  {
    "path": "examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Canny.py",
    "content": "from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput\nfrom PIL import Image\nimport torch\nfrom modelscope import dataset_snapshot_download\n\n\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n        ModelConfig(model_id=\"DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny\", origin_file_pattern=\"model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n)\npipe.load_lora(pipe.dit, \"models/train/Qwen-Image-Blockwise-ControlNet-Canny_lora/epoch-4.safetensors\")\n\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/example_image_dataset\",\n    local_dir=\"./data/example_image_dataset\",\n    allow_file_pattern=\"canny/image_1.jpg\"\n)\ncontrolnet_image = Image.open(\"data/example_image_dataset/canny/image_1.jpg\").resize((1328, 1328))\n\nprompt = \"一只小狗，毛发光洁柔顺，眼神灵动，背景是樱花纷飞的春日庭院，唯美温馨。\"\nimage = pipe(\n    prompt, seed=0,\n    blockwise_controlnet_inputs=[ControlNetInput(image=controlnet_image)]\n)\nimage.save(\"image.jpg\")\n"
  },
  {
    "path": "examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Depth.py",
    "content": "from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput\nfrom PIL import Image\nimport torch\nfrom modelscope import dataset_snapshot_download\n\n\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n        ModelConfig(model_id=\"DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth\", origin_file_pattern=\"model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n)\npipe.load_lora(pipe.dit, \"models/train/Qwen-Image-Blockwise-ControlNet-Depth_lora/epoch-4.safetensors\")\n\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/example_image_dataset\",\n    local_dir=\"./data/example_image_dataset\",\n    allow_file_pattern=\"depth/image_1.jpg\"\n)\n\ncontrolnet_image = Image.open(\"data/example_image_dataset/depth/image_1.jpg\").resize((1328, 1328))\n\nprompt = \"精致肖像，水下少女，蓝裙飘逸，发丝轻扬，光影透澈，气泡环绕，面容恬静，细节精致，梦幻唯美。\"\nimage = pipe(\n    prompt, seed=0,\n    blockwise_controlnet_inputs=[ControlNetInput(image=controlnet_image)]\n)\nimage.save(\"image.jpg\")\n"
  },
  {
    "path": "examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py",
    "content": "import torch\nfrom PIL import Image\nfrom modelscope import dataset_snapshot_download\nfrom diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput\n\n\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n        ModelConfig(model_id=\"DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint\", origin_file_pattern=\"model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n)\npipe.load_lora(pipe.dit, \"models/train/Qwen-Image-Blockwise-ControlNet-Inpaint_lora/epoch-4.safetensors\")\n\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/example_image_dataset\",\n    local_dir=\"./data/example_image_dataset\",\n    allow_file_pattern=\"inpaint/*.jpg\"\n)\nprompt = \"a cat with sunglasses\"\ncontrolnet_image = Image.open(\"./data/example_image_dataset/inpaint/image_1.jpg\").convert(\"RGB\").resize((1024, 1024))\ninpaint_mask = Image.open(\"./data/example_image_dataset/inpaint/mask.jpg\").convert(\"RGB\").resize((1024, 1024))\nimage = pipe(\n    prompt, seed=0,\n    blockwise_controlnet_inputs=[ControlNetInput(image=controlnet_image, inpaint_mask=inpaint_mask)],\n    height=1024, width=1024,\n    num_inference_steps=40,\n)\nimage.save(\"image.jpg\")\n"
  },
  {
    "path": "examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py",
    "content": "from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nimport torch\n\n\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"DiffSynth-Studio/Qwen-Image-Distill-Full\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n)\npipe.load_lora(pipe.dit, \"models/train/Qwen-Image-Distill-Full_lora/epoch-4.safetensors\")\nprompt = \"a dog\"\nimage = pipe(prompt, seed=0, num_inference_steps=15, cfg_scale=1)\nimage.save(\"image.jpg\")\n"
  },
  {
    "path": "examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-LoRA.py",
    "content": "from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nimport torch\n\n\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n)\npipe.load_lora(pipe.dit, \"models/train/Qwen-Image-Distill-LoRA_lora/epoch-4.safetensors\")\nprompt = \"精致肖像，水下少女，蓝裙飘逸，发丝轻扬，光影透澈，气泡环绕，面容恬静，细节精致，梦幻唯美。\"\nimage = pipe(\n    prompt,\n    seed=0,\n    num_inference_steps=4,\n    cfg_scale=1,\n)\nimage.save(\"image.jpg\")\n"
  },
  {
    "path": "examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\n\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image-Edit-2509\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=None,\n    processor_config=ModelConfig(model_id=\"Qwen/Qwen-Image-Edit\", origin_file_pattern=\"processor/\"),\n)\npipe.load_lora(pipe.dit, \"models/train/Qwen-Image-Edit-2509_lora/epoch-4.safetensors\")\n\nprompt = \"Change the color of the dress in Figure 1 to the color shown in Figure 2.\"\nimages = [\n    Image.open(\"data/example_image_dataset/edit/image1.jpg\").resize((1024, 1024)),\n    Image.open(\"data/example_image_dataset/edit/image_color.jpg\").resize((1024, 1024)),\n]\nimage = pipe(prompt, edit_image=images, seed=123, num_inference_steps=40, height=1024, width=1024)\nimage.save(\"image.jpg\")\n"
  },
  {
    "path": "examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\n\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image-Edit-2511\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=None,\n    processor_config=ModelConfig(model_id=\"Qwen/Qwen-Image-Edit\", origin_file_pattern=\"processor/\"),\n)\npipe.load_lora(pipe.dit, \"models/train/Qwen-Image-Edit-2511_lora/epoch-4.safetensors\")\n\nprompt = \"Change the color of the dress in Figure 1 to the color shown in Figure 2.\"\nimages = [\n    Image.open(\"data/example_image_dataset/edit/image1.jpg\").resize((1024, 1024)),\n    Image.open(\"data/example_image_dataset/edit/image_color.jpg\").resize((1024, 1024)),\n]\nimage = pipe(prompt, edit_image=images, seed=123, num_inference_steps=40, height=1024, width=1024, zero_cond_t=True)\nimage.save(\"image.jpg\")\n"
  },
  {
    "path": "examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\n\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image-Edit\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=None,\n    processor_config=ModelConfig(model_id=\"Qwen/Qwen-Image-Edit\", origin_file_pattern=\"processor/\"),\n)\npipe.load_lora(pipe.dit, \"models/train/Qwen-Image-Edit_lora/epoch-4.safetensors\")\n\nprompt = \"将裙子改为粉色\"\nimage = Image.open(\"data/example_image_dataset/edit/image1.jpg\").resize((1024, 1024))\nimage = pipe(prompt, edit_image=image, seed=0, num_inference_steps=40, height=1024, width=1024)\nimage.save(f\"image.jpg\")\n"
  },
  {
    "path": "examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py",
    "content": "from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nimport torch\nfrom PIL import Image\n\n\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n)\npipe.load_lora(pipe.dit, \"models/train/Qwen-Image-EliGen-Poster_lora/epoch-4.safetensors\")\n\n\nentity_prompts = [\"A beautiful girl\", \"sign 'Entity Control'\", \"shorts\", \"shirt\"]\nglobal_prompt = \"A beautiful girl wearing shirt and shorts in the street,  holding a sign 'Entity Control'\"\nmasks = [Image.open(f\"data/example_image_dataset/eligen/{i}.png\").convert('RGB') for i in range(len(entity_prompts))]\n\nimage = pipe(global_prompt,\n             seed=0,\n             height=1024,\n             width=1024,\n             eligen_entity_prompts=entity_prompts,\n             eligen_entity_masks=masks)\nimage.save(\"Qwen-Image-EliGen-Poster.jpg\")\n"
  },
  {
    "path": "examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py",
    "content": "from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nimport torch\nfrom PIL import Image\n\n\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n)\npipe.load_lora(pipe.dit, \"models/train/Qwen-Image-EliGen_lora/epoch-4.safetensors\")\n\n\nentity_prompts = [\"A beautiful girl\", \"sign 'Entity Control'\", \"shorts\", \"shirt\"]\nglobal_prompt = \"A beautiful girl wearing shirt and shorts in the street,  holding a sign 'Entity Control'\"\nmasks = [Image.open(f\"data/example_image_dataset/eligen/{i}.png\").convert('RGB') for i in range(len(entity_prompts))]\n\nimage = pipe(global_prompt,\n             seed=0,\n             height=1024,\n             width=1024,\n             eligen_entity_prompts=entity_prompts,\n             eligen_entity_masks=masks)\nimage.save(\"Qwen-Image_EliGen.jpg\")\n"
  },
  {
    "path": "examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py",
    "content": "from PIL import Image\nimport torch\nfrom diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\n\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n)\npipe.load_lora(pipe.dit, \"models/train/Qwen-Image-In-Context-Control-Union_lora/epoch-4.safetensors\")\nimage = Image.open(\"data/example_image_dataset/canny/image_1.jpg\").resize((1024, 1024))\nprompt = \"Context_Control. a dog\"\nimage = pipe(prompt=prompt, seed=0, context_image=image, height=1024, width=1024)\nimage.save(\"image_context.jpg\")\n"
  },
  {
    "path": "examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control-V2.py",
    "content": "from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\nfrom PIL import Image\nimport torch\n\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"DiffSynth-Studio/Qwen-Image-Layered-Control\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image-Layered\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n)\npipe.load_lora(pipe.dit, \"models/train/Qwen-Image-Layered-Control-V2_lora/epoch-4.safetensors\")\n\nprompt = \"Text 'APRIL'\"\ninput_image = Image.open(\"data/example_image_dataset/layer_v2/image_1.png\").convert(\"RGBA\").resize((1024, 1024))\nimage = pipe(\n    prompt, seed=0,\n    height=1024, width=1024,\n    layer_input_image=input_image, layer_num=0,\n    num_inference_steps=10, cfg_scale=4,\n)\nimage[0].save(\"image_prompt.png\")\n\nmask_image = Image.open(\"data/example_image_dataset/layer_v2/mask_2.png\").convert(\"RGBA\").resize((1024, 1024))\ninput_image = Image.open(\"data/example_image_dataset/layer_v2/image_2.png\").convert(\"RGBA\").resize((1024, 1024))\nimage = pipe(\n    prompt, seed=0,\n    height=1024, width=1024,\n    layer_input_image=input_image, layer_num=0,\n    context_image=mask_image,\n    num_inference_steps=10, cfg_scale=1.0,\n)\nimage[0].save(\"image_mask.png\")\n"
  },
  {
    "path": "examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control.py",
    "content": "from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nfrom diffsynth import load_state_dict\nfrom PIL import Image\nimport torch\n\n\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"DiffSynth-Studio/Qwen-Image-Layered-Control\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image-Layered\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n)\npipe.load_lora(pipe.dit, \"models/train/Qwen-Image-Layered-Control_lora/epoch-4.safetensors\")\nprompt = \"Text 'HELLO' and 'Have a great day'\"\ninput_image = Image.open(\"data/example_image_dataset/layer/image.png\").convert(\"RGBA\").resize((864, 480))\nimages = pipe(\n    prompt, seed=0,\n    height=480, width=864,\n    layer_input_image=input_image, layer_num=0,\n)\nimages[0].save(\"image.png\")\n"
  },
  {
    "path": "examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py",
    "content": "from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nfrom diffsynth import load_state_dict\nfrom PIL import Image\nimport torch\n\n\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image-Layered\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image-Layered\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n)\npipe.load_lora(pipe.dit, \"models/train/Qwen-Image-Layered_lora/epoch-4.safetensors\")\nprompt = \"a poster\"\ninput_image = Image.open(\"data/example_image_dataset/layer/image.png\").convert(\"RGBA\").resize((864, 480))\nimages = pipe(\n    prompt, seed=0,\n    height=480, width=864,\n    layer_input_image=input_image, layer_num=3,\n)\nfor i, image in enumerate(images):\n    if i == 0: continue # The first image is the input image.\n    image.save(f\"image_{i}.png\")\n"
  },
  {
    "path": "examples/qwen_image/model_training/validate_lora/Qwen-Image.py",
    "content": "from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig\nimport torch\n\n\npipe = QwenImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"text_encoder/model*.safetensors\"),\n        ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Qwen/Qwen-Image\", origin_file_pattern=\"tokenizer/\"),\n)\npipe.load_lora(pipe.dit, \"models/train/Qwen-Image_lora/epoch-4.safetensors\")\nprompt = \"a dog\"\nimage = pipe(prompt, seed=0)\nimage.save(\"image.jpg\")\n"
  },
  {
    "path": "examples/wanvideo/README.md",
    "content": "English Document: https://diffsynth-studio-doc.readthedocs.io/en/latest/Model_Details/Wan.html\n\n中文文档：https://diffsynth-studio-doc.readthedocs.io/zh-cn/latest/Model_Details/Wan.html\n"
  },
  {
    "path": "examples/wanvideo/acceleration/unified_sequence_parallel.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nimport torch.distributed as dist\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    use_usp=True,\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-14B\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-14B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-14B\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n)\n\n# Text-to-video\nvideo = pipe(\n    prompt=\"一名宇航员身穿太空服，面朝镜头骑着一匹机械马在火星表面驰骋。红色的荒凉地表延伸至远方，点缀着巨大的陨石坑和奇特的岩石结构。机械马的步伐稳健，扬起微弱的尘埃，展现出未来科技与原始探索的完美结合。宇航员手持操控装置，目光坚定，仿佛正在开辟人类的新疆域。背景是深邃的宇宙和蔚蓝的地球，画面既科幻又充满希望，让人不禁畅想未来的星际生活。\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    seed=0, tiled=True,\n)\nif dist.get_rank() == 0:\n    save_video(video, \"video1.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_inference/LongCat-Video.py",
    "content": "import torch\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"meituan-longcat/LongCat-Video\", origin_file_pattern=\"dit/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-14B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-14B\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n)\n\n# Text-to-video\nvideo = pipe(\n    prompt=\"In a realistic photography style, a white boy around seven or eight years old sits on a park bench, wearing a light blue T-shirt, denim shorts, and white sneakers. He holds an ice cream cone with vanilla and chocolate flavors, and beside him is a medium-sized golden Labrador. Smiling, the boy offers the ice cream to the dog, who eagerly licks it with its tongue. 
The sun is shining brightly, and the background features a green lawn and several tall trees, creating a warm and loving scene.\",\n    negative_prompt=\"Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards\",\n    seed=0, tiled=True, num_frames=93,\n    cfg_scale=2, sigma_shift=1,\n)\nsave_video(video, \"video_1_LongCat-Video.mp4\", fps=15, quality=5)\n\n# Video-continuation (The number of frames in `longcat_video` should be 4n+1.)\nlongcat_video = video[-17:]\nvideo = pipe(\n    prompt=\"In a realistic photography style, a white boy around seven or eight years old sits on a park bench, wearing a light blue T-shirt, denim shorts, and white sneakers. He holds an ice cream cone with vanilla and chocolate flavors, and beside him is a medium-sized golden Labrador. Smiling, the boy offers the ice cream to the dog, who eagerly licks it with its tongue. The sun is shining brightly, and the background features a green lawn and several tall trees, creating a warm and loving scene.\",\n    negative_prompt=\"Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards\",\n    seed=1, tiled=True, num_frames=93,\n    cfg_scale=2, sigma_shift=1,\n    longcat_video=longcat_video,\n)\nsave_video(video, \"video_2_LongCat-Video.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_inference/Video-As-Prompt-Wan2.1-14B.py",
    "content": "import torch\nimport PIL\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\nfrom typing import List\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"ByteDance/Video-As-Prompt-Wan2.1-14B\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-I2V-14B-720P\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-I2V-14B-720P\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-I2V-14B-720P\", origin_file_pattern=\"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n)\n\ndataset_snapshot_download(\"DiffSynth-Studio/example_video_dataset\", allow_file_pattern=\"wanvap/*\", local_dir=\"data/example_video_dataset\")\nref_video_path = 'data/example_video_dataset/wanvap/vap_ref.mp4'\ntarget_image_path = 'data/example_video_dataset/wanvap/input_image.jpg'\n\ndef select_frames(video_frames, num):\n    idx = torch.linspace(0, len(video_frames) - 1, num).long().tolist()\n    return [video_frames[i] for i in idx]\n\nimage = Image.open(target_image_path).convert(\"RGB\")\nref_video = VideoData(ref_video_path, height=480, width=832)\nref_frames = select_frames(ref_video, num=49)\n\nvap_prompt = \"A man stands with his back to the camera on a dirt path overlooking sun-drenched, rolling green tea plantations. He wears a blue and green plaid shirt, dark pants, and white shoes. As he turns to face the camera and spreads his arms, a brief, magical burst of sparkling golden light particles envelops him. 
Through this shimmer, he seamlessly transforms into a Labubu toy character. His head morphs into the iconic large, furry-eared head of the toy, featuring a wide grin with pointed teeth and red cheek markings. The character retains the man's original plaid shirt and clothing, which now fit its stylized, cartoonish body. The camera remains static throughout the transformation, positioned low among the tea bushes, maintaining a consistent view of the subject and the expansive scenery.\"\nprompt = \"A young woman with curly hair, wearing a green hijab and a floral dress, plays a violin in front of a vintage green car on a tree-lined street. She executes a swift counter-clockwise turn to face the camera. During the turn, a brilliant shower of golden, sparkling particles erupts and momentarily obscures her figure. As the particles fade, she is revealed to have seamlessly transformed into a Labubu toy character. This new figure, now with the toy's signature large ears, big eyes, and toothy grin, maintains the original pose and continues playing the violin. The character's clothing—the green hijab, floral dress, and black overcoat—remains identical to the woman's. 
Throughout this transition, the camera stays static, and the street-side environment remains completely consistent.\"\nnegative_prompt = \"Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards\"\n\nvideo = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    input_image=image,\n    seed=42, tiled=True,\n    height=480, width=832,\n    num_frames=49,\n    vap_video=ref_frames,\n    vap_prompt=vap_prompt,\n    negative_vap_prompt=negative_prompt,\n)\nsave_video(video, \"video_Video-As-Prompt-Wan2.1-14B.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n        ModelConfig(model_id=\"DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1\", origin_file_pattern=\"model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n)\n\n# Text-to-video\nvideo = pipe(\n    prompt=\"纪实摄影风格画面，一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄，两只耳朵立起，神情专注而欢快。阳光洒在它身上，使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地，偶尔点缀着几朵野花，远处隐约可见蓝天和几片白云。透视感鲜明，捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    seed=1, tiled=True,\n    motion_bucket_id=0\n)\nsave_video(video, \"video_slow_Wan2.1-1.3b-speedcontrol-v1.mp4\", fps=15, quality=5)\n\nvideo = pipe(\n    prompt=\"纪实摄影风格画面，一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄，两只耳朵立起，神情专注而欢快。阳光洒在它身上，使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地，偶尔点缀着几朵野花，远处隐约可见蓝天和几片白云。透视感鲜明，捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    seed=1, tiled=True,\n    motion_bucket_id=100\n)\nsave_video(video, \"video_fast_Wan2.1-1.3b-speedcontrol-v1.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_inference/Wan2.1-FLF2V-14B-720P.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-FLF2V-14B-720P\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-FLF2V-14B-720P\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-FLF2V-14B-720P\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-FLF2V-14B-720P\", origin_file_pattern=\"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n)\n\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/examples_in_diffsynth\",\n    local_dir=\"./\",\n    allow_file_pattern=[\"data/examples/wan/first_frame.jpeg\", \"data/examples/wan/last_frame.jpeg\"]\n)\n\n# First and last frame to video\nvideo = pipe(\n    prompt=\"写实风格，一个女生手持枯萎的花站在花园中，镜头逐渐拉远，记录下花园的全貌。\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    input_image=Image.open(\"data/examples/wan/first_frame.jpeg\").resize((960, 960)),\n    end_image=Image.open(\"data/examples/wan/last_frame.jpeg\").resize((960, 960)),\n    seed=0, tiled=True,\n    height=960, width=960, num_frames=33,\n    sigma_shift=16,\n)\nsave_video(video, \"video_Wan2.1-FLF2V-14B-720P.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-Control.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-1.3B-Control\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-1.3B-Control\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-1.3B-Control\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-1.3B-Control\", origin_file_pattern=\"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n)\n\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/examples_in_diffsynth\",\n    local_dir=\"./\",\n    allow_file_pattern=f\"data/examples/wan/control_video.mp4\"\n)\n\n# Control video\ncontrol_video = VideoData(\"data/examples/wan/control_video.mp4\", height=832, width=576)\nvideo = pipe(\n    prompt=\"扁平风格动漫，一位长发少女优雅起舞。她五官精致，大眼睛明亮有神，黑色长发柔顺光泽。身穿淡蓝色T恤和深蓝色牛仔短裤。背景是粉色。\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    control_video=control_video, height=832, width=576, num_frames=49,\n    seed=1, tiled=True\n)\nsave_video(video, \"video_Wan2.1-Fun-1.3B-Control.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-InP.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-1.3B-InP\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-1.3B-InP\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-1.3B-InP\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-1.3B-InP\", origin_file_pattern=\"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n)\n\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/examples_in_diffsynth\",\n    local_dir=\"./\",\n    allow_file_pattern=f\"data/examples/wan/input_image.jpg\"\n)\nimage = Image.open(\"data/examples/wan/input_image.jpg\")\n\n# First and last frame to video\nvideo = pipe(\n    prompt=\"一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌，白色的浪花拍打着船身，但小船毫不畏惧，坚定地驶向远方。阳光洒在水面上，闪烁着金色的光芒，为这壮丽的场景增添了一抹温暖。镜头拉近，可以看到船上的旗帜迎风飘扬，象征着不屈的精神与冒险的勇气。这段画面充满力量，激励人心，展现了面对挑战时的无畏与执着。\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    input_image=image,\n    seed=0, tiled=True\n    # You can input `end_image=xxx` to control the last frame of the video.\n    # The model will automatically generate the dynamic content between `input_image` and `end_image`.\n)\nsave_video(video, \"video_Wan2.1-Fun-1.3B-InP.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_inference/Wan2.1-Fun-14B-Control.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-14B-Control\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-14B-Control\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-14B-Control\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-14B-Control\", origin_file_pattern=\"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n)\n\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/examples_in_diffsynth\",\n    local_dir=\"./\",\n    allow_file_pattern=f\"data/examples/wan/control_video.mp4\"\n)\n\n# Control video\ncontrol_video = VideoData(\"data/examples/wan/control_video.mp4\", height=832, width=576)\nvideo = pipe(\n    prompt=\"扁平风格动漫，一位长发少女优雅起舞。她五官精致，大眼睛明亮有神，黑色长发柔顺光泽。身穿淡蓝色T恤和深蓝色牛仔短裤。背景是粉色。\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    control_video=control_video, height=832, width=576, num_frames=49,\n    seed=1, tiled=True\n)\nsave_video(video, \"video_Wan2.1-Fun-14B-Control.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_inference/Wan2.1-Fun-14B-InP.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-14B-InP\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-14B-InP\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-14B-InP\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-14B-InP\", origin_file_pattern=\"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n)\n\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/examples_in_diffsynth\",\n    local_dir=\"./\",\n    allow_file_pattern=f\"data/examples/wan/input_image.jpg\"\n)\nimage = Image.open(\"data/examples/wan/input_image.jpg\")\n\n# First and last frame to video\nvideo = pipe(\n    prompt=\"一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌，白色的浪花拍打着船身，但小船毫不畏惧，坚定地驶向远方。阳光洒在水面上，闪烁着金色的光芒，为这壮丽的场景增添了一抹温暖。镜头拉近，可以看到船上的旗帜迎风飘扬，象征着不屈的精神与冒险的勇气。这段画面充满力量，激励人心，展现了面对挑战时的无畏与执着。\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    input_image=image,\n    seed=0, tiled=True\n    # You can input `end_image=xxx` to control the last frame of the video.\n    # The model will automatically generate the dynamic content between `input_image` and `end_image`.\n)\nsave_video(video, \"video_Wan2.1-Fun-14B-InP.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera\", origin_file_pattern=\"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n)\n\n\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/examples_in_diffsynth\",\n    local_dir=\"./\",\n    allow_file_pattern=f\"data/examples/wan/input_image.jpg\"\n)\ninput_image = Image.open(\"data/examples/wan/input_image.jpg\")\n\nvideo = pipe(\n    prompt=\"一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌，白色的浪花拍打着船身，但小船毫不畏惧，坚定地驶向远方。阳光洒在水面上，闪烁着金色的光芒，为这壮丽的场景增添了一抹温暖。镜头拉近，可以看到船上的旗帜迎风飘扬，象征着不屈的精神与冒险的勇气。这段画面充满力量，激励人心，展现了面对挑战时的无畏与执着。\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    seed=0, tiled=True,\n    input_image=input_image,\n    camera_control_direction=\"Left\", camera_control_speed=0.01,\n)\nsave_video(video, \"video_left_Wan2.1-Fun-V1.1-1.3B-Control-Camera.mp4\", fps=15, quality=5)\n\nvideo = pipe(\n    
prompt=\"一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌，白色的浪花拍打着船身，但小船毫不畏惧，坚定地驶向远方。阳光洒在水面上，闪烁着金色的光芒，为这壮丽的场景增添了一抹温暖。镜头拉近，可以看到船上的旗帜迎风飘扬，象征着不屈的精神与冒险的勇气。这段画面充满力量，激励人心，展现了面对挑战时的无畏与执着。\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    seed=0, tiled=True,\n    input_image=input_image,\n    camera_control_direction=\"Up\", camera_control_speed=0.01,\n)\nsave_video(video, \"video_up_Wan2.1-Fun-V1.1-1.3B-Control-Camera.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-1.3B-Control\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-1.3B-Control\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-1.3B-Control\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-1.3B-Control\", origin_file_pattern=\"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n)\n\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/examples_in_diffsynth\",\n    local_dir=\"./\",\n    allow_file_pattern=[\"data/examples/wan/control_video.mp4\", \"data/examples/wan/reference_image_girl.png\"]\n)\n\n# Control video\ncontrol_video = VideoData(\"data/examples/wan/control_video.mp4\", height=832, width=576)\nreference_image = Image.open(\"data/examples/wan/reference_image_girl.png\").resize((576, 832))\nvideo = pipe(\n    prompt=\"扁平风格动漫，一位长发少女优雅起舞。她五官精致，大眼睛明亮有神，黑色长发柔顺光泽。身穿淡蓝色T恤和深蓝色牛仔短裤。背景是粉色。\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    control_video=control_video, reference_image=reference_image,\n    height=832, width=576, num_frames=49,\n    seed=1, tiled=True\n)\nsave_video(video, \"video_Wan2.1-Fun-V1.1-1.3B-Control.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-1.3B-InP\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-1.3B-InP\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-1.3B-InP\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-1.3B-InP\", origin_file_pattern=\"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n)\n\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/examples_in_diffsynth\",\n    local_dir=\"./\",\n    allow_file_pattern=f\"data/examples/wan/input_image.jpg\"\n)\nimage = Image.open(\"data/examples/wan/input_image.jpg\")\n\n# First and last frame to video\nvideo = pipe(\n    prompt=\"一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌，白色的浪花拍打着船身，但小船毫不畏惧，坚定地驶向远方。阳光洒在水面上，闪烁着金色的光芒，为这壮丽的场景增添了一抹温暖。镜头拉近，可以看到船上的旗帜迎风飘扬，象征着不屈的精神与冒险的勇气。这段画面充满力量，激励人心，展现了面对挑战时的无畏与执着。\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    input_image=image,\n    seed=0, tiled=True\n    # You can input `end_image=xxx` to control the last frame of the video.\n    # The model will automatically generate the dynamic content between `input_image` and `end_image`.\n)\nsave_video(video, \"video_Wan2.1-Fun-V1.1-1.3B-InP.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control-Camera.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-14B-Control-Camera\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-14B-Control-Camera\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-14B-Control-Camera\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-14B-Control-Camera\", origin_file_pattern=\"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n)\n\n\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/examples_in_diffsynth\",\n    local_dir=\"./\",\n    allow_file_pattern=f\"data/examples/wan/input_image.jpg\"\n)\ninput_image = Image.open(\"data/examples/wan/input_image.jpg\")\n\nvideo = pipe(\n    prompt=\"一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌，白色的浪花拍打着船身，但小船毫不畏惧，坚定地驶向远方。阳光洒在水面上，闪烁着金色的光芒，为这壮丽的场景增添了一抹温暖。镜头拉近，可以看到船上的旗帜迎风飘扬，象征着不屈的精神与冒险的勇气。这段画面充满力量，激励人心，展现了面对挑战时的无畏与执着。\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    seed=0, tiled=True,\n    input_image=input_image,\n    camera_control_direction=\"Left\", camera_control_speed=0.01,\n)\nsave_video(video, \"video_left_Wan2.1-Fun-V1.1-14B-Control-Camera.mp4\", fps=15, quality=5)\n\nvideo = pipe(\n    
prompt=\"一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌，白色的浪花拍打着船身，但小船毫不畏惧，坚定地驶向远方。阳光洒在水面上，闪烁着金色的光芒，为这壮丽的场景增添了一抹温暖。镜头拉近，可以看到船上的旗帜迎风飘扬，象征着不屈的精神与冒险的勇气。这段画面充满力量，激励人心，展现了面对挑战时的无畏与执着。\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    seed=0, tiled=True,\n    input_image=input_image,\n    camera_control_direction=\"Up\", camera_control_speed=0.01,\n)\nsave_video(video, \"video_up_Wan2.1-Fun-V1.1-14B-Control-Camera.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-14B-Control\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-14B-Control\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-14B-Control\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-14B-Control\", origin_file_pattern=\"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n)\n\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/examples_in_diffsynth\",\n    local_dir=\"./\",\n    allow_file_pattern=[\"data/examples/wan/control_video.mp4\", \"data/examples/wan/reference_image_girl.png\"]\n)\n\n# Control video\ncontrol_video = VideoData(\"data/examples/wan/control_video.mp4\", height=832, width=576)\nreference_image = Image.open(\"data/examples/wan/reference_image_girl.png\").resize((576, 832))\nvideo = pipe(\n    prompt=\"扁平风格动漫，一位长发少女优雅起舞。她五官精致，大眼睛明亮有神，黑色长发柔顺光泽。身穿淡蓝色T恤和深蓝色牛仔短裤。背景是粉色。\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    control_video=control_video, reference_image=reference_image,\n    height=832, width=576, num_frames=49,\n    seed=1, tiled=True\n)\nsave_video(video, \"video_Wan2.1-Fun-V1.1-14B-Control.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-InP.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-14B-InP\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-14B-InP\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-14B-InP\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-14B-InP\", origin_file_pattern=\"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n)\n\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/examples_in_diffsynth\",\n    local_dir=\"./\",\n    allow_file_pattern=f\"data/examples/wan/input_image.jpg\"\n)\nimage = Image.open(\"data/examples/wan/input_image.jpg\")\n\n# First and last frame to video\nvideo = pipe(\n    prompt=\"一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌，白色的浪花拍打着船身，但小船毫不畏惧，坚定地驶向远方。阳光洒在水面上，闪烁着金色的光芒，为这壮丽的场景增添了一抹温暖。镜头拉近，可以看到船上的旗帜迎风飘扬，象征着不屈的精神与冒险的勇气。这段画面充满力量，激励人心，展现了面对挑战时的无畏与执着。\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    input_image=image,\n    seed=0, tiled=True\n    # You can input `end_image=xxx` to control the last frame of the video.\n    # The model will automatically generate the dynamic content between `input_image` and `end_image`.\n)\nsave_video(video, \"video_Wan2.1-Fun-V1.1-14B-InP.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-I2V-14B-480P\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-I2V-14B-480P\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-I2V-14B-480P\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-I2V-14B-480P\", origin_file_pattern=\"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n)\n\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/examples_in_diffsynth\",\n    local_dir=\"./\",\n    allow_file_pattern=f\"data/examples/wan/input_image.jpg\"\n)\nimage = Image.open(\"data/examples/wan/input_image.jpg\")\n\n# Image-to-video\nvideo = pipe(\n    prompt=\"一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌，白色的浪花拍打着船身，但小船毫不畏惧，坚定地驶向远方。阳光洒在水面上，闪烁着金色的光芒，为这壮丽的场景增添了一抹温暖。镜头拉近，可以看到船上的旗帜迎风飘扬，象征着不屈的精神与冒险的勇气。这段画面充满力量，激励人心，展现了面对挑战时的无畏与执着。\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    input_image=image,\n    seed=0, tiled=True\n)\nsave_video(video, \"video_Wan2.1-I2V-14B-480P.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_inference/Wan2.1-I2V-14B-720P.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-I2V-14B-720P\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-I2V-14B-720P\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-I2V-14B-720P\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-I2V-14B-720P\", origin_file_pattern=\"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n)\n\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/examples_in_diffsynth\",\n    local_dir=\"./\",\n    allow_file_pattern=f\"data/examples/wan/input_image.jpg\"\n)\nimage = Image.open(\"data/examples/wan/input_image.jpg\")\n\n# Image-to-video\nvideo = pipe(\n    prompt=\"一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌，白色的浪花拍打着船身，但小船毫不畏惧，坚定地驶向远方。阳光洒在水面上，闪烁着金色的光芒，为这壮丽的场景增添了一抹温暖。镜头拉近，可以看到船上的旗帜迎风飘扬，象征着不屈的精神与冒险的勇气。这段画面充满力量，激励人心，展现了面对挑战时的无畏与执着。\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    input_image=image,\n    seed=0, tiled=True,\n    height=720, width=1280,\n)\nsave_video(video, \"video_Wan2.1-I2V-14B-720P.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n)\n\n# Text-to-video\nvideo = pipe(\n    prompt=\"纪实摄影风格画面，一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄，两只耳朵立起，神情专注而欢快。阳光洒在它身上，使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地，偶尔点缀着几朵野花，远处隐约可见蓝天和几片白云。透视感鲜明，捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    seed=0, tiled=True,\n)\nsave_video(video, \"video_1_Wan2.1-T2V-1.3B.mp4\", fps=15, quality=5)\n\n# Video-to-video\nvideo = VideoData(\"video_1_Wan2.1-T2V-1.3B.mp4\", height=480, width=832)\nvideo = pipe(\n    prompt=\"纪实摄影风格画面，一只活泼的小狗戴着黑色墨镜在绿茵茵的草地上迅速奔跑。小狗毛色棕黄，戴着黑色墨镜，两只耳朵立起，神情专注而欢快。阳光洒在它身上，使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地，偶尔点缀着几朵野花，远处隐约可见蓝天和几片白云。透视感鲜明，捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    input_video=video, denoising_strength=0.7,\n    seed=1, tiled=True\n)\nsave_video(video, \"video_2_Wan2.1-T2V-1.3B.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_inference/Wan2.1-T2V-14B.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-14B\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-14B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-14B\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n)\n\n# Text-to-video\nvideo = pipe(\n    prompt=\"一名宇航员身穿太空服，面朝镜头骑着一匹机械马在火星表面驰骋。红色的荒凉地表延伸至远方，点缀着巨大的陨石坑和奇特的岩石结构。机械马的步伐稳健，扬起微弱的尘埃，展现出未来科技与原始探索的完美结合。宇航员手持操控装置，目光坚定，仿佛正在开辟人类的新疆域。背景是深邃的宇宙和蔚蓝的地球，画面既科幻又充满希望，让人不禁畅想未来的星际生活。\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    seed=0, tiled=True,\n)\nsave_video(video, \"video_Wan2.1-T2V-14B.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_inference/Wan2.1-VACE-1.3B-Preview.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"iic/VACE-Wan2.1-1.3B-Preview\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"iic/VACE-Wan2.1-1.3B-Preview\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"iic/VACE-Wan2.1-1.3B-Preview\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n)\n\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/examples_in_diffsynth\",\n    local_dir=\"./\",\n    allow_file_pattern=[\"data/examples/wan/depth_video.mp4\", \"data/examples/wan/cat_fightning.jpg\"]\n)\n\n# Depth video -> Video\ncontrol_video = VideoData(\"data/examples/wan/depth_video.mp4\", height=480, width=832)\nvideo = pipe(\n    prompt=\"两只可爱的橘猫戴上拳击手套，站在一个拳击台上搏斗。\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    vace_video=control_video,\n    seed=1, tiled=True\n)\nsave_video(video, \"video_1_Wan2.1-VACE-1.3B-Preview.mp4\", fps=15, quality=5)\n\n# Reference image -> Video\nvideo = pipe(\n    prompt=\"两只可爱的橘猫戴上拳击手套，站在一个拳击台上搏斗。\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    vace_reference_image=Image.open(\"data/examples/wan/cat_fightning.jpg\").resize((832, 480)),\n    seed=1, tiled=True\n)\nsave_video(video, \"video_2_Wan2.1-VACE-1.3B-Preview.mp4\", fps=15, quality=5)\n\n# Depth 
video + Reference image -> Video\nvideo = pipe(\n    prompt=\"两只可爱的橘猫戴上拳击手套，站在一个拳击台上搏斗。\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    vace_video=control_video,\n    vace_reference_image=Image.open(\"data/examples/wan/cat_fightning.jpg\").resize((832, 480)),\n    seed=1, tiled=True\n)\nsave_video(video, \"video_3_Wan2.1-VACE-1.3B-Preview.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-VACE-1.3B\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-VACE-1.3B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-VACE-1.3B\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n)\n\n\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/examples_in_diffsynth\",\n    local_dir=\"./\",\n    allow_file_pattern=[\"data/examples/wan/depth_video.mp4\", \"data/examples/wan/cat_fightning.jpg\"]\n)\n\n# Depth video -> Video\ncontrol_video = VideoData(\"data/examples/wan/depth_video.mp4\", height=480, width=832)\nvideo = pipe(\n    prompt=\"两只可爱的橘猫戴上拳击手套，站在一个拳击台上搏斗。\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    vace_video=control_video,\n    seed=1, tiled=True\n)\nsave_video(video, \"video_1_Wan2.1-VACE-1.3B.mp4\", fps=15, quality=5)\n\n# Reference image -> Video\nvideo = pipe(\n    prompt=\"两只可爱的橘猫戴上拳击手套，站在一个拳击台上搏斗。\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    vace_reference_image=Image.open(\"data/examples/wan/cat_fightning.jpg\").resize((832, 480)),\n    seed=1, tiled=True\n)\nsave_video(video, \"video_2_Wan2.1-VACE-1.3B.mp4\", fps=15, quality=5)\n\n# Depth video + Reference image -> 
Video\nvideo = pipe(\n    prompt=\"两只可爱的橘猫戴上拳击手套，站在一个拳击台上搏斗。\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    vace_video=control_video,\n    vace_reference_image=Image.open(\"data/examples/wan/cat_fightning.jpg\").resize((832, 480)),\n    seed=1, tiled=True\n)\nsave_video(video, \"video_3_Wan2.1-VACE-1.3B.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_inference/Wan2.1-VACE-14B.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-VACE-14B\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-VACE-14B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-VACE-14B\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n)\n\n\n\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/examples_in_diffsynth\",\n    local_dir=\"./\",\n    allow_file_pattern=[\"data/examples/wan/depth_video.mp4\", \"data/examples/wan/cat_fightning.jpg\"]\n)\n\n# Depth video -> Video\ncontrol_video = VideoData(\"data/examples/wan/depth_video.mp4\", height=480, width=832)\nvideo = pipe(\n    prompt=\"两只可爱的橘猫戴上拳击手套，站在一个拳击台上搏斗。\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    vace_video=control_video,\n    seed=1, tiled=True\n)\nsave_video(video, \"video_1_Wan2.1-VACE-14B.mp4\", fps=15, quality=5)\n\n# Reference image -> Video\nvideo = pipe(\n    prompt=\"两只可爱的橘猫戴上拳击手套，站在一个拳击台上搏斗。\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    vace_reference_image=Image.open(\"data/examples/wan/cat_fightning.jpg\").resize((832, 480)),\n    seed=1, tiled=True\n)\nsave_video(video, \"video_2_Wan2.1-VACE-14B.mp4\", fps=15, quality=5)\n\n# Depth video + Reference image -> 
Video\nvideo = pipe(\n    prompt=\"两只可爱的橘猫戴上拳击手套，站在一个拳击台上搏斗。\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    vace_video=control_video,\n    vace_reference_image=Image.open(\"data/examples/wan/cat_fightning.jpg\").resize((832, 480)),\n    seed=1, tiled=True\n)\nsave_video(video, \"video_3_Wan2.1-VACE-14B.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_inference/Wan2.2-Animate-14B.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.core import load_state_dict\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download, snapshot_download\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-Animate-14B\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-Animate-14B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-Animate-14B\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-Animate-14B\", origin_file_pattern=\"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n)\n\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/examples_in_diffsynth\",\n    local_dir=\"./\",\n    allow_file_pattern=\"data/examples/wan/animate/*\",\n)\n\n# Animate\ninput_image = Image.open(\"data/examples/wan/animate/animate_input_image.png\")\nanimate_pose_video = VideoData(\"data/examples/wan/animate/animate_pose_video.mp4\").raw_data()[:81-4]\nanimate_face_video = VideoData(\"data/examples/wan/animate/animate_face_video.mp4\").raw_data()[:81-4]\nvideo = pipe(\n    prompt=\"视频中的人在做动作\",\n    seed=0, tiled=True,\n    input_image=input_image,\n    animate_pose_video=animate_pose_video,\n    animate_face_video=animate_face_video,\n    num_frames=81, height=720, width=1280,\n    num_inference_steps=20, cfg_scale=1,\n)\nsave_video(video, \"video_1_Wan2.2-Animate-14B.mp4\", fps=15, quality=5)\n\n# Replace\nsnapshot_download(\"Wan-AI/Wan2.2-Animate-14B\", allow_file_pattern=\"relighting_lora.ckpt\", 
local_dir=\"models/Wan-AI/Wan2.2-Animate-14B\")\nlora_state_dict = load_state_dict(\"models/Wan-AI/Wan2.2-Animate-14B/relighting_lora.ckpt\", torch_dtype=torch.bfloat16, device=\"cuda\")[\"state_dict\"]\npipe.load_lora(pipe.dit, state_dict=lora_state_dict)\ninput_image = Image.open(\"data/examples/wan/animate/replace_input_image.png\")\nanimate_pose_video = VideoData(\"data/examples/wan/animate/replace_pose_video.mp4\").raw_data()[:81-4]\nanimate_face_video = VideoData(\"data/examples/wan/animate/replace_face_video.mp4\").raw_data()[:81-4]\nanimate_inpaint_video = VideoData(\"data/examples/wan/animate/replace_inpaint_video.mp4\").raw_data()[:81-4]\nanimate_mask_video = VideoData(\"data/examples/wan/animate/replace_mask_video.mp4\").raw_data()[:81-4]\nvideo = pipe(\n    prompt=\"视频中的人在做动作\",\n    seed=0, tiled=True,\n    input_image=input_image,\n    animate_pose_video=animate_pose_video,\n    animate_face_video=animate_face_video,\n    animate_inpaint_video=animate_inpaint_video,\n    animate_mask_video=animate_mask_video,\n    num_frames=81, height=720, width=1280,\n    num_inference_steps=20, cfg_scale=1,\n)\nsave_video(video, \"video_2_Wan2.2-Animate-14B.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py",
    "content": "import torch\nfrom diffsynth.utils.data import save_video,VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom PIL import Image\nfrom modelscope import dataset_snapshot_download\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"PAI/Wan2.2-Fun-A14B-Control-Camera\", origin_file_pattern=\"high_noise_model/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"PAI/Wan2.2-Fun-A14B-Control-Camera\", origin_file_pattern=\"low_noise_model/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"PAI/Wan2.2-Fun-A14B-Control-Camera\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"PAI/Wan2.2-Fun-A14B-Control-Camera\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n)\n\n\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/examples_in_diffsynth\",\n    local_dir=\"./\",\n    allow_file_pattern=f\"data/examples/wan/input_image.jpg\"\n)\ninput_image = Image.open(\"data/examples/wan/input_image.jpg\")\n\nvideo = pipe(\n    prompt=\"一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌，白色的浪花拍打着船身，但小船毫不畏惧，坚定地驶向远方。阳光洒在水面上，闪烁着金色的光芒，为这壮丽的场景增添了一抹温暖。镜头拉近，可以看到船上的旗帜迎风飘扬，象征着不屈的精神与冒险的勇气。这段画面充满力量，激励人心，展现了面对挑战时的无畏与执着。\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    seed=0, tiled=True,\n    input_image=input_image,\n    camera_control_direction=\"Left\", camera_control_speed=0.01,\n)\nsave_video(video, \"video_left_Wan2.2-Fun-A14B-Control-Camera.mp4\", fps=15, quality=5)\n\nvideo = pipe(\n    
prompt=\"一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌，白色的浪花拍打着船身，但小船毫不畏惧，坚定地驶向远方。阳光洒在水面上，闪烁着金色的光芒，为这壮丽的场景增添了一抹温暖。镜头拉近，可以看到船上的旗帜迎风飘扬，象征着不屈的精神与冒险的勇气。这段画面充满力量，激励人心，展现了面对挑战时的无畏与执着。\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    seed=0, tiled=True,\n    input_image=input_image,\n    camera_control_direction=\"Up\", camera_control_speed=0.01,\n)\nsave_video(video, \"video_up_Wan2.2-Fun-A14B-Control-Camera.mp4\", fps=15, quality=5)"
  },
  {
    "path": "examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py",
    "content": "import torch\nfrom diffsynth.utils.data import save_video,VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom PIL import Image\nfrom modelscope import dataset_snapshot_download\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"PAI/Wan2.2-Fun-A14B-Control\", origin_file_pattern=\"high_noise_model/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"PAI/Wan2.2-Fun-A14B-Control\", origin_file_pattern=\"low_noise_model/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"PAI/Wan2.2-Fun-A14B-Control\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"PAI/Wan2.2-Fun-A14B-Control\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n)\n\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/examples_in_diffsynth\",\n    local_dir=\"./\",\n    allow_file_pattern=[\"data/examples/wan/control_video.mp4\", \"data/examples/wan/reference_image_girl.png\"]\n)\n\n# Control video\ncontrol_video = VideoData(\"data/examples/wan/control_video.mp4\", height=832, width=576)\nreference_image = Image.open(\"data/examples/wan/reference_image_girl.png\").resize((576, 832))\nvideo = pipe(\n    prompt=\"扁平风格动漫，一位长发少女优雅起舞。她五官精致，大眼睛明亮有神，黑色长发柔顺光泽。身穿淡蓝色T恤和深蓝色牛仔短裤。背景是粉色。\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    control_video=control_video, reference_image=reference_image,\n    height=832, width=576, num_frames=49,\n    seed=1, tiled=True\n)\nsave_video(video, \"video_Wan2.2-Fun-A14B-Control.mp4\", fps=15, quality=5)"
  },
  {
    "path": "examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py",
    "content": "import torch\nfrom diffsynth.utils.data import save_video\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom PIL import Image\nfrom modelscope import dataset_snapshot_download\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"PAI/Wan2.2-Fun-A14B-InP\", origin_file_pattern=\"high_noise_model/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"PAI/Wan2.2-Fun-A14B-InP\", origin_file_pattern=\"low_noise_model/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"PAI/Wan2.2-Fun-A14B-InP\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"PAI/Wan2.2-Fun-A14B-InP\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n)\n\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/examples_in_diffsynth\",\n    local_dir=\"./\",\n    allow_file_pattern=f\"data/examples/wan/input_image.jpg\"\n)\nimage = Image.open(\"data/examples/wan/input_image.jpg\")\n\n# First and last frame to video\nvideo = pipe(\n    prompt=\"一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌，白色的浪花拍打着船身，但小船毫不畏惧，坚定地驶向远方。阳光洒在水面上，闪烁着金色的光芒，为这壮丽的场景增添了一抹温暖。镜头拉近，可以看到船上的旗帜迎风飘扬，象征着不屈的精神与冒险的勇气。这段画面充满力量，激励人心，展现了面对挑战时的无畏与执着。\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    input_image=image,\n    seed=0, tiled=True,\n    # You can input `end_image=xxx` to control the last frame of the video.\n    # The model will automatically generate the dynamic content between `input_image` and `end_image`.\n)\nsave_video(video, \"video_Wan2.2-Fun-A14B-InP.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-I2V-A14B\", origin_file_pattern=\"high_noise_model/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-I2V-A14B\", origin_file_pattern=\"low_noise_model/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-I2V-A14B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-I2V-A14B\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n)\n\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/examples_in_diffsynth\",\n    local_dir=\"./\",\n    allow_file_pattern=[\"data/examples/wan/cat_fightning.jpg\"]\n)\ninput_image = Image.open(\"data/examples/wan/cat_fightning.jpg\").resize((832, 480))\n\nvideo = pipe(\n    prompt=\"Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    seed=0, tiled=True,\n    input_image=input_image,\n    switch_DiT_boundary=0.9,\n)\nsave_video(video, \"video_Wan2.2-I2V-A14B.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_inference/Wan2.2-S2V-14B.py",
    "content": "# This script can generate a single video clip.\n# If you need to generate long videos, please refer to `Wan2.2-S2V-14B_multi_clips.py`.\nimport torch\nfrom PIL import Image\nimport librosa\nfrom diffsynth.utils.data import VideoData, save_video_with_audio\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-S2V-14B\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-S2V-14B\", origin_file_pattern=\"wav2vec2-large-xlsr-53-english/model.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-S2V-14B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-S2V-14B\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n    audio_processor_config=ModelConfig(model_id=\"Wan-AI/Wan2.2-S2V-14B\", origin_file_pattern=\"wav2vec2-large-xlsr-53-english/\"),\n)\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/example_video_dataset\",\n    local_dir=\"./data/example_video_dataset\",\n    allow_file_pattern=f\"wans2v/*\"\n)\n\nnum_frames = 81 # 4n+1\nheight = 448\nwidth = 832\n\nprompt = \"a person is singing\"\nnegative_prompt = \"画面模糊，最差质量，画面模糊，细节模糊不清，情绪激动剧烈，手快速抖动，字幕，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\"\ninput_image = Image.open(\"data/example_video_dataset/wans2v/pose.png\").convert(\"RGB\").resize((width, height))\n# s2v audio input, recommend 16kHz sampling rate\naudio_path = 'data/example_video_dataset/wans2v/sing.MP3'\ninput_audio, sample_rate = librosa.load(audio_path, sr=16000)\n\n# Speech-to-video\nvideo = pipe(\n    prompt=prompt,\n    
input_image=input_image,\n    negative_prompt=negative_prompt,\n    seed=0,\n    num_frames=num_frames,\n    height=height,\n    width=width,\n    audio_sample_rate=sample_rate,\n    input_audio=input_audio,\n    num_inference_steps=40,\n)\nsave_video_with_audio(video[1:], \"video_1_Wan2.2-S2V-14B.mp4\", audio_path, fps=16, quality=5)\n\n# s2v uses the first num_frames frames of the pose video as reference. The pose video's height and width must match input_image, and its fps should be 16 to match the output video.\npose_video_path = 'data/example_video_dataset/wans2v/pose.mp4'\npose_video = VideoData(pose_video_path, height=height, width=width)\n\n# Speech-to-video with pose\nvideo = pipe(\n    prompt=prompt,\n    input_image=input_image,\n    negative_prompt=negative_prompt,\n    seed=0,\n    num_frames=num_frames,\n    height=height,\n    width=width,\n    audio_sample_rate=sample_rate,\n    input_audio=input_audio,\n    s2v_pose_video=pose_video,\n    num_inference_steps=40,\n)\nsave_video_with_audio(video[1:], \"video_2_Wan2.2-S2V-14B.mp4\", audio_path, fps=16, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py",
    "content": "import torch\nfrom PIL import Image\nimport librosa\nfrom diffsynth.utils.data import VideoData, save_video_with_audio\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig, WanVideoUnit_S2V\nfrom modelscope import dataset_snapshot_download\n\n\ndef speech_to_video(\n    prompt,\n    input_image,\n    audio_path,\n    negative_prompt=\"\",\n    num_clip=None,\n    audio_sample_rate=16000,\n    pose_video_path=None,\n    infer_frames=80,\n    height=448,\n    width=832,\n    num_inference_steps=40,\n    fps=16, # recommend fixing fps as 16 for s2v\n    motion_frames=73, # hyperparameter of wan2.2-s2v\n    save_path=None,\n):\n    # s2v audio input, recommend 16kHz sampling rate\n    input_audio, sample_rate = librosa.load(audio_path, sr=audio_sample_rate)\n    # s2v will use the first (num_frames) frames as reference. height and width must be the same as input_image. And fps should be 16, the same as output video fps.\n    pose_video = VideoData(pose_video_path, height=height, width=width) if pose_video_path is not None else None\n\n    with torch.no_grad():\n        audio_embeds, pose_latents, num_repeat = WanVideoUnit_S2V.pre_calculate_audio_pose(\n            pipe=pipe,\n            input_audio=input_audio,\n            audio_sample_rate=sample_rate,\n            s2v_pose_video=pose_video,\n            num_frames=infer_frames + 1,\n            height=height,\n            width=width,\n            fps=fps,\n        )\n    num_repeat = min(num_repeat, num_clip) if num_clip is not None else num_repeat\n    print(f\"Generating {num_repeat} video clips...\")\n    motion_video = None\n    video = []\n    for r in range(num_repeat):\n        s2v_pose_latents = pose_latents[r] if pose_latents is not None else None\n        current_clip_tensor = pipe(\n            prompt=prompt,\n            input_image=input_image,\n            negative_prompt=negative_prompt,\n            seed=0,\n            num_frames=infer_frames + 1,\n            
height=height,\n            width=width,\n            audio_embeds=audio_embeds[r],\n            s2v_pose_latents=s2v_pose_latents,\n            motion_video=motion_video,\n            num_inference_steps=num_inference_steps,\n            output_type=\"floatpoint\",\n        )\n        # (B, C, T, H, W)\n        current_clip_tensor = current_clip_tensor[:,:,-infer_frames:,:,:]\n        if r == 0:\n            current_clip_tensor = current_clip_tensor[:,:,3:,:,:]\n            overlap_frames_num = min(motion_frames, current_clip_tensor.shape[2])\n            motion_video = current_clip_tensor[:,:,-overlap_frames_num:,:,:].clone()\n        else:\n            overlap_frames_num = min(motion_frames, current_clip_tensor.shape[2])\n            motion_video = torch.cat((motion_video[:,:,overlap_frames_num:,:,:], current_clip_tensor[:,:,-overlap_frames_num:,:,:]), dim=2)\n        current_clip_quantized = pipe.vae_output_to_video(current_clip_tensor)\n        video.extend(current_clip_quantized)\n        save_video_with_audio(video, save_path, audio_path, fps=16, quality=5)\n        print(f\"processed the {r+1}th clip of total {num_repeat} clips.\")\n    return video\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-S2V-14B\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-S2V-14B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-S2V-14B\", origin_file_pattern=\"wav2vec2-large-xlsr-53-english/model.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-S2V-14B\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n    audio_processor_config=ModelConfig(model_id=\"Wan-AI/Wan2.2-S2V-14B\", 
origin_file_pattern=\"wav2vec2-large-xlsr-53-english/\"),\n)\n\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/example_video_dataset\",\n    local_dir=\"./data/example_video_dataset\",\n    allow_file_pattern=f\"wans2v/*\",\n)\n\ninfer_frames = 80  # 4n\nheight = 448\nwidth = 832\n\nprompt = \"a person is singing\"\nnegative_prompt = \"画面模糊，最差质量，画面模糊，细节模糊不清，情绪激动剧烈，手快速抖动，字幕，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\"\ninput_image = Image.open(\"data/example_video_dataset/wans2v/pose.png\").convert(\"RGB\").resize((width, height))\n\nvideo_with_audio = speech_to_video(\n    prompt=prompt,\n    input_image=input_image,\n    audio_path='data/example_video_dataset/wans2v/sing.MP3',\n    negative_prompt=negative_prompt,\n    pose_video_path='data/example_video_dataset/wans2v/pose.mp4',\n    save_path=\"video_full_Wan2.2-S2V-14B.mp4\",\n    infer_frames=infer_frames,\n    height=height,\n    width=width,\n)\n# num_clip means generating only the first n clips with n * infer_frames frames.\nvideo_with_audio_pose = speech_to_video(\n    prompt=prompt,\n    input_image=input_image,\n    audio_path='data/example_video_dataset/wans2v/sing.MP3',\n    negative_prompt=negative_prompt,\n    pose_video_path='data/example_video_dataset/wans2v/pose.mp4',\n    save_path=\"video_clip_2_Wan2.2-S2V-14B.mp4\",\n    num_clip=2\n)\n"
  },
  {
    "path": "examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py",
    "content": "import torch\nfrom diffsynth.utils.data import save_video\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-T2V-A14B\", origin_file_pattern=\"high_noise_model/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-T2V-A14B\", origin_file_pattern=\"low_noise_model/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-T2V-A14B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-T2V-A14B\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n)\n\n# Text-to-video\nvideo = pipe(\n    prompt=\"Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    seed=0, tiled=True,\n)\nsave_video(video, \"video_Wan2.2-T2V-A14B.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-TI2V-5B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-TI2V-5B\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-TI2V-5B\", origin_file_pattern=\"Wan2.2_VAE.pth\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n)\n\n# Text-to-video\nvideo = pipe(\n    prompt=\"两只可爱的橘猫戴上拳击手套，站在一个拳击台上搏斗。\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    seed=0, tiled=True,\n    height=704, width=1248,\n    num_frames=121,\n)\nsave_video(video, \"video_1_Wan2.2-TI2V-5B.mp4\", fps=15, quality=5)\n\n# Image-to-video\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/examples_in_diffsynth\",\n    local_dir=\"./\",\n    allow_file_pattern=[\"data/examples/wan/cat_fightning.jpg\"]\n)\ninput_image = Image.open(\"data/examples/wan/cat_fightning.jpg\").resize((1248, 704))\nvideo = pipe(\n    prompt=\"两只可爱的橘猫戴上拳击手套，站在一个拳击台上搏斗。\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    seed=0, tiled=True,\n    height=704, width=1248,\n    input_image=input_image,\n    num_frames=121,\n)\nsave_video(video, \"video_2_Wan2.2-TI2V-5B.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_inference/Wan2.2-VACE-Fun-A14B.py",
    "content": "# Without VRAM management, 80 GB of VRAM is not enough to run this example.\n# We recommend using `examples/wanvideo/model_inference_low_vram/Wan2.2-VACE-Fun-A14B.py`.\n# CPU Offload is enabled in this example.\nimport torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\nvram_config = {\n    \"offload_dtype\": torch.bfloat16,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"PAI/Wan2.2-VACE-Fun-A14B\", origin_file_pattern=\"high_noise_model/diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"PAI/Wan2.2-VACE-Fun-A14B\", origin_file_pattern=\"low_noise_model/diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"PAI/Wan2.2-VACE-Fun-A14B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\", **vram_config),\n        ModelConfig(model_id=\"PAI/Wan2.2-VACE-Fun-A14B\", origin_file_pattern=\"Wan2.1_VAE.pth\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 2,\n)\n\n\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/examples_in_diffsynth\",\n    local_dir=\"./\",\n    allow_file_pattern=[\"data/examples/wan/depth_video.mp4\", \"data/examples/wan/cat_fightning.jpg\"]\n)\n\n# Depth video -> Video\ncontrol_video = VideoData(\"data/examples/wan/depth_video.mp4\", height=480, width=832)\nvideo = pipe(\n  
  prompt=\"两只可爱的橘猫戴上拳击手套，站在一个拳击台上搏斗。\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    vace_video=control_video,\n    seed=1, tiled=True\n)\nsave_video(video, \"video_1_Wan2.2-VACE-Fun-A14B.mp4\", fps=15, quality=5)\n\n# Reference image -> Video\nvideo = pipe(\n    prompt=\"两只可爱的橘猫戴上拳击手套，站在一个拳击台上搏斗。\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    vace_reference_image=Image.open(\"data/examples/wan/cat_fightning.jpg\").resize((832, 480)),\n    seed=1, tiled=True\n)\nsave_video(video, \"video_2_Wan2.2-VACE-Fun-A14B.mp4\", fps=15, quality=5)\n\n# Depth video + Reference image -> Video\nvideo = pipe(\n    prompt=\"两只可爱的橘猫戴上拳击手套，站在一个拳击台上搏斗。\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    vace_video=control_video,\n    vace_reference_image=Image.open(\"data/examples/wan/cat_fightning.jpg\").resize((832, 480)),\n    seed=1, tiled=True\n)\nsave_video(video, \"video_3_Wan2.2-VACE-Fun-A14B.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_inference/WanToDance-14B-global.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/WanToDance-14B\", origin_file_pattern=\"global_model.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/WanToDance-14B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"Wan-AI/WanToDance-14B\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n        ModelConfig(model_id=\"Wan-AI/WanToDance-14B\", origin_file_pattern=\"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n)\ndataset_snapshot_download(\n    \"DiffSynth-Studio/diffsynth_example_dataset\",\n    local_dir=\"data/diffsynth_example_dataset\",\n    allow_file_pattern=\"wanvideo/WanToDance-14B-global/*\"\n)\n# This is a specialized model with the following constraints on its input parameters:\n# *   The model outputs a sequence of keyframes rather than a video; therefore, `framewise_decoding=True` must be set.\n# *   When the number of keyframes is `n`, `num_frames` = 4 * (n - 1) + 1.\n# *   Reducing `height`, `width`, `num_frames`, or `num_inference_steps` may lead to severe artifacts or generation failure.\n# *   The audio file specified by `wantodance_music_path` must match the video duration, calculated as (`num_frames` / 7.5) seconds.\n# *   The width and height of `wantodance_reference_image` must be multiples of 16.\n# *   `wantodance_fps` is configurable, but since the model appears to have been trained exclusively at 7.5 FPS, setting it to other values is not recommended.\n# *   The first frame of `wantodance_keyframes` is the 
`wantodance_reference_image`, while all subsequent frames are solid black.\n# *   `wantodance_keyframes_mask` indicates the positions of valid frames within `wantodance_keyframes`.\nwantodance_keyframes = VideoData(\"data/diffsynth_example_dataset/wanvideo/WanToDance-14B-global/keyframes.mp4\")\nwantodance_keyframes = [wantodance_keyframes[i] for i in range(149)]\nvideo = pipe(\n    prompt=\"一个人正在跳舞，舞蹈种类是韩舞。帧率是7.5000\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    seed=0, tiled=False,\n    height=1280, width=720, num_frames=149,\n    num_inference_steps=48,\n    wantodance_music_path=\"data/diffsynth_example_dataset/wanvideo/WanToDance-14B-global/music.WAV\",\n    wantodance_reference_image=Image.open(\"data/diffsynth_example_dataset/wanvideo/WanToDance-14B-global/refimage.jpg\"),\n    wantodance_fps=7.5,\n    wantodance_keyframes=wantodance_keyframes,\n    wantodance_keyframes_mask=[1] + [0] * 148,\n    framewise_decoding=True,\n)\nsave_video(video, \"video_WanToDance-14B-global.mp4\", fps=7.5, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_inference/WanToDance-14B-local.py",
    "content": "import torch, os\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/WanToDance-14B\", origin_file_pattern=\"local_model.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/WanToDance-14B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"Wan-AI/WanToDance-14B\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n        ModelConfig(model_id=\"Wan-AI/WanToDance-14B\", origin_file_pattern=\"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n)\ndataset_snapshot_download(\n    \"DiffSynth-Studio/diffsynth_example_dataset\",\n    local_dir=\"data/diffsynth_example_dataset\",\n    allow_file_pattern=\"wanvideo/WanToDance-14B-local/*\"\n)\n# This is a specialized model with the following constraints on its input parameters:\n# *   The model renders and outputs video based on a sequence of keyframes; therefore, `wantodance_keyframes` must be provided correctly.\n# *   If you need to generate a long video, please generate it in segments, and ensure that `wantodance_music_path`, `wantodance_keyframes`, and `wantodance_keyframes_mask` are properly split accordingly.\n# *   The audio file specified by `wantodance_music_path` must match the video duration, calculated as (`num_frames` / 30) seconds.\n# *   The width and height of `wantodance_reference_image` must be multiples of 16.\n# *   `wantodance_fps` is configurable, but since the model appears to have been trained exclusively at 30 FPS, setting it to other values is not recommended.\n# *   In `wantodance_keyframes`, frames that are not 
keyframes should be solid black.\n# *   `wantodance_keyframes_mask` indicates the positions of valid frames within `wantodance_keyframes`.\nwantodance_keyframes = VideoData(\"data/diffsynth_example_dataset/wanvideo/WanToDance-14B-local/keyframes.mp4\")\nwantodance_keyframes = [wantodance_keyframes[i] for i in range(149)]\nvideo = pipe(\n    prompt=\"一个人正在跳舞，舞蹈种类是古典舞,图像清晰程度高,人物动作平均幅度中等,人物动作最大幅度中等。, 帧率是30fps。\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    seed=0, tiled=True,\n    height=1280, width=720, num_frames=149,\n    num_inference_steps=24,\n    wantodance_music_path=\"data/diffsynth_example_dataset/wanvideo/WanToDance-14B-local/music.wav\",\n    wantodance_reference_image=Image.open(\"data/diffsynth_example_dataset/wanvideo/WanToDance-14B-local/refimage.jpg\"),\n    wantodance_fps=30,\n    wantodance_keyframes=wantodance_keyframes,\n    wantodance_keyframes_mask=[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n                               1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n                               1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n                               1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n                               1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n                               1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n                               1],\n)\nsave_video(video, \"video_WanToDance-14B-local.mp4\", fps=30, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_inference/krea-realtime-video.py",
    "content": "import torch\nfrom diffsynth.utils.data import save_video\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"krea/krea-realtime-video\", origin_file_pattern=\"krea-realtime-video-14b.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-14B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-14B\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n)\n\n# Text-to-video\nvideo = pipe(\n    prompt=\"a cat sitting on a boat\",\n    num_inference_steps=6, num_frames=81,\n    seed=0, tiled=True,\n    cfg_scale=1,\n    sigma_shift=20,\n)\nsave_video(video, \"video_krea-realtime-video.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_inference_low_vram/LongCat-Video.py",
    "content": "import torch\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\n\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"meituan-longcat/LongCat-Video\", origin_file_pattern=\"dit/diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-14B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\", **vram_config),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-14B\", origin_file_pattern=\"Wan2.1_VAE.pth\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 2,\n)\n\n# Text-to-video\nvideo = pipe(\n    prompt=\"In a realistic photography style, a white boy around seven or eight years old sits on a park bench, wearing a light blue T-shirt, denim shorts, and white sneakers. He holds an ice cream cone with vanilla and chocolate flavors, and beside him is a medium-sized golden Labrador. Smiling, the boy offers the ice cream to the dog, who eagerly licks it with its tongue. 
The sun is shining brightly, and the background features a green lawn and several tall trees, creating a warm and loving scene.\",\n    negative_prompt=\"Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards\",\n    seed=0, tiled=True, num_frames=93,\n    cfg_scale=2, sigma_shift=1,\n)\nsave_video(video, \"video_1_LongCat-Video.mp4\", fps=15, quality=5)\n\n# Video-continuation (The number of frames in `longcat_video` should be 4n+1.)\nlongcat_video = video[-17:]\nvideo = pipe(\n    prompt=\"In a realistic photography style, a white boy around seven or eight years old sits on a park bench, wearing a light blue T-shirt, denim shorts, and white sneakers. He holds an ice cream cone with vanilla and chocolate flavors, and beside him is a medium-sized golden Labrador. Smiling, the boy offers the ice cream to the dog, who eagerly licks it with its tongue. The sun is shining brightly, and the background features a green lawn and several tall trees, creating a warm and loving scene.\",\n    negative_prompt=\"Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards\",\n    seed=1, tiled=True, num_frames=93,\n    cfg_scale=2, sigma_shift=1,\n    longcat_video=longcat_video,\n)\nsave_video(video, \"video_2_LongCat-Video.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_inference_low_vram/Video-As-Prompt-Wan2.1-14B.py",
    "content": "import torch\nimport PIL\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\nfrom typing import List\n\n\n# This model doesn't support fine-grained VRAM Management due to its special architecture.\n# Only CPU Offload is supported.\nvram_config = {\n    \"offload_dtype\": torch.bfloat16,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cuda\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"ByteDance/Video-As-Prompt-Wan2.1-14B\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-I2V-14B-720P\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\", **vram_config),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-I2V-14B-720P\", origin_file_pattern=\"Wan2.1_VAE.pth\", **vram_config),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-I2V-14B-720P\", origin_file_pattern=\"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 2,\n)\n\ndataset_snapshot_download(\"DiffSynth-Studio/example_video_dataset\", allow_file_pattern=\"wanvap/*\", local_dir=\"data/example_video_dataset\")\nref_video_path = 'data/example_video_dataset/wanvap/vap_ref.mp4'\ntarget_image_path = 'data/example_video_dataset/wanvap/input_image.jpg'\n\ndef select_frames(video_frames, num):\n    idx = torch.linspace(0, len(video_frames) - 1, 
num).long().tolist()\n    return [video_frames[i] for i in idx]\n\nimage = Image.open(target_image_path).convert(\"RGB\")\nref_video = VideoData(ref_video_path, height=480, width=832)\nref_frames = select_frames(ref_video, num=49)\n\nvap_prompt = \"A man stands with his back to the camera on a dirt path overlooking sun-drenched, rolling green tea plantations. He wears a blue and green plaid shirt, dark pants, and white shoes. As he turns to face the camera and spreads his arms, a brief, magical burst of sparkling golden light particles envelops him. Through this shimmer, he seamlessly transforms into a Labubu toy character. His head morphs into the iconic large, furry-eared head of the toy, featuring a wide grin with pointed teeth and red cheek markings. The character retains the man's original plaid shirt and clothing, which now fit its stylized, cartoonish body. The camera remains static throughout the transformation, positioned low among the tea bushes, maintaining a consistent view of the subject and the expansive scenery.\"\nprompt = \"A young woman with curly hair, wearing a green hijab and a floral dress, plays a violin in front of a vintage green car on a tree-lined street. She executes a swift counter-clockwise turn to face the camera. During the turn, a brilliant shower of golden, sparkling particles erupts and momentarily obscures her figure. As the particles fade, she is revealed to have seamlessly transformed into a Labubu toy character. This new figure, now with the toy's signature large ears, big eyes, and toothy grin, maintains the original pose and continues playing the violin. The character's clothing—the green hijab, floral dress, and black overcoat—remains identical to the woman's. 
Throughout this transition, the camera stays static, and the street-side environment remains completely consistent.\"\nnegative_prompt = \"Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards\"\n\nvideo = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    input_image=image,\n    seed=42, tiled=True,\n    height=480, width=832,\n    num_frames=49,\n    vap_video=ref_frames,\n    vap_prompt=vap_prompt,\n    negative_vap_prompt=negative_prompt,\n)\nsave_video(video, \"video_Video-As-Prompt-Wan2.1-14B.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_inference_low_vram/Wan2.1-1.3b-speedcontrol-v1.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\n\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\", **vram_config),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"Wan2.1_VAE.pth\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1\", origin_file_pattern=\"model.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 2,\n)\n\n# Text-to-video\nvideo = pipe(\n    prompt=\"纪实摄影风格画面，一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄，两只耳朵立起，神情专注而欢快。阳光洒在它身上，使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地，偶尔点缀着几朵野花，远处隐约可见蓝天和几片白云。透视感鲜明，捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    seed=1, tiled=True,\n    motion_bucket_id=0\n)\nsave_video(video, \"video_slow_Wan2.1-1.3b-speedcontrol-v1.mp4\", fps=15, quality=5)\n\nvideo = pipe(\n    
prompt=\"纪实摄影风格画面，一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄，两只耳朵立起，神情专注而欢快。阳光洒在它身上，使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地，偶尔点缀着几朵野花，远处隐约可见蓝天和几片白云。透视感鲜明，捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    seed=1, tiled=True,\n    motion_bucket_id=100\n)\nsave_video(video, \"video_fast_Wan2.1-1.3b-speedcontrol-v1.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_inference_low_vram/Wan2.1-FLF2V-14B-720P.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-FLF2V-14B-720P\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-FLF2V-14B-720P\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\", **vram_config),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-FLF2V-14B-720P\", origin_file_pattern=\"Wan2.1_VAE.pth\", **vram_config),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-FLF2V-14B-720P\", origin_file_pattern=\"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 2,\n)\n\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/examples_in_diffsynth\",\n    local_dir=\"./\",\n    allow_file_pattern=[\"data/examples/wan/first_frame.jpeg\", \"data/examples/wan/last_frame.jpeg\"]\n)\n\n# First and last frame to video\nvideo = pipe(\n    prompt=\"写实风格，一个女生手持枯萎的花站在花园中，镜头逐渐拉远，记录下花园的全貌。\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    
input_image=Image.open(\"data/examples/wan/first_frame.jpeg\").resize((960, 960)),\n    end_image=Image.open(\"data/examples/wan/last_frame.jpeg\").resize((960, 960)),\n    seed=0, tiled=True,\n    height=960, width=960, num_frames=33,\n    sigma_shift=16,\n)\nsave_video(video, \"video_Wan2.1-FLF2V-14B-720P.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-1.3B-Control.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-1.3B-Control\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-1.3B-Control\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\", **vram_config),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-1.3B-Control\", origin_file_pattern=\"Wan2.1_VAE.pth\", **vram_config),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-1.3B-Control\", origin_file_pattern=\"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 2,\n)\n\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/examples_in_diffsynth\",\n    local_dir=\"./\",\n    allow_file_pattern=f\"data/examples/wan/control_video.mp4\"\n)\n\n# Control video\ncontrol_video = VideoData(\"data/examples/wan/control_video.mp4\", height=832, width=576)\nvideo = pipe(\n    prompt=\"扁平风格动漫，一位长发少女优雅起舞。她五官精致，大眼睛明亮有神，黑色长发柔顺光泽。身穿淡蓝色T恤和深蓝色牛仔短裤。背景是粉色。\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    
control_video=control_video, height=832, width=576, num_frames=49,\n    seed=1, tiled=True\n)\nsave_video(video, \"video_Wan2.1-Fun-1.3B-Control.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-1.3B-InP.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-1.3B-InP\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-1.3B-InP\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\", **vram_config),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-1.3B-InP\", origin_file_pattern=\"Wan2.1_VAE.pth\", **vram_config),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-1.3B-InP\", origin_file_pattern=\"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 2,\n)\n\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/examples_in_diffsynth\",\n    local_dir=\"./\",\n    allow_file_pattern=f\"data/examples/wan/input_image.jpg\"\n)\nimage = Image.open(\"data/examples/wan/input_image.jpg\")\n\n# Image-to-video from the first frame (see the end_image note below for first-and-last-frame mode)\nvideo = pipe(\n    prompt=\"一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌，白色的浪花拍打着船身，但小船毫不畏惧，坚定地驶向远方。阳光洒在水面上，闪烁着金色的光芒，为这壮丽的场景增添了一抹温暖。镜头拉近，可以看到船上的旗帜迎风飘扬，象征着不屈的精神与冒险的勇气。这段画面充满力量，激励人心，展现了面对挑战时的无畏与执着。\",\n    
negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    input_image=image,\n    seed=0, tiled=True\n    # You can input `end_image=xxx` to control the last frame of the video.\n    # The model will automatically generate the dynamic content between `input_image` and `end_image`.\n)\nsave_video(video, \"video_Wan2.1-Fun-1.3B-InP.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-14B-Control.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-14B-Control\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-14B-Control\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\", **vram_config),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-14B-Control\", origin_file_pattern=\"Wan2.1_VAE.pth\", **vram_config),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-14B-Control\", origin_file_pattern=\"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 2,\n)\n\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/examples_in_diffsynth\",\n    local_dir=\"./\",\n    allow_file_pattern=f\"data/examples/wan/control_video.mp4\"\n)\n\n# Control video\ncontrol_video = VideoData(\"data/examples/wan/control_video.mp4\", height=832, width=576)\nvideo = pipe(\n    prompt=\"扁平风格动漫，一位长发少女优雅起舞。她五官精致，大眼睛明亮有神，黑色长发柔顺光泽。身穿淡蓝色T恤和深蓝色牛仔短裤。背景是粉色。\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    
control_video=control_video, height=832, width=576, num_frames=49,\n    seed=1, tiled=True\n)\nsave_video(video, \"video_Wan2.1-Fun-14B-Control.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-14B-InP.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-14B-InP\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-14B-InP\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\", **vram_config),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-14B-InP\", origin_file_pattern=\"Wan2.1_VAE.pth\", **vram_config),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-14B-InP\", origin_file_pattern=\"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 2,\n)\n\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/examples_in_diffsynth\",\n    local_dir=\"./\",\n    allow_file_pattern=f\"data/examples/wan/input_image.jpg\"\n)\nimage = Image.open(\"data/examples/wan/input_image.jpg\")\n\n# Image-to-video from the first frame (see the end_image note below for first-and-last-frame mode)\nvideo = pipe(\n    prompt=\"一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌，白色的浪花拍打着船身，但小船毫不畏惧，坚定地驶向远方。阳光洒在水面上，闪烁着金色的光芒，为这壮丽的场景增添了一抹温暖。镜头拉近，可以看到船上的旗帜迎风飘扬，象征着不屈的精神与冒险的勇气。这段画面充满力量，激励人心，展现了面对挑战时的无畏与执着。\",\n    
negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    input_image=image,\n    seed=0, tiled=True\n    # You can input `end_image=xxx` to control the last frame of the video.\n    # The model will automatically generate the dynamic content between `input_image` and `end_image`.\n)\nsave_video(video, \"video_Wan2.1-Fun-14B-InP.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\", **vram_config),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera\", origin_file_pattern=\"Wan2.1_VAE.pth\", **vram_config),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera\", origin_file_pattern=\"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 2,\n)\n\n\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/examples_in_diffsynth\",\n    local_dir=\"./\",\n    allow_file_pattern=f\"data/examples/wan/input_image.jpg\"\n)\ninput_image = Image.open(\"data/examples/wan/input_image.jpg\")\n\nvideo = pipe(\n    prompt=\"一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌，白色的浪花拍打着船身，但小船毫不畏惧，坚定地驶向远方。阳光洒在水面上，闪烁着金色的光芒，为这壮丽的场景增添了一抹温暖。镜头拉近，可以看到船上的旗帜迎风飘扬，象征着不屈的精神与冒险的勇气。这段画面充满力量，激励人心，展现了面对挑战时的无畏与执着。\",\n    
negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    seed=0, tiled=True,\n    input_image=input_image,\n    camera_control_direction=\"Left\", camera_control_speed=0.01,\n)\nsave_video(video, \"video_left_Wan2.1-Fun-V1.1-1.3B-Control-Camera.mp4\", fps=15, quality=5)\n\nvideo = pipe(\n    prompt=\"一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌，白色的浪花拍打着船身，但小船毫不畏惧，坚定地驶向远方。阳光洒在水面上，闪烁着金色的光芒，为这壮丽的场景增添了一抹温暖。镜头拉近，可以看到船上的旗帜迎风飘扬，象征着不屈的精神与冒险的勇气。这段画面充满力量，激励人心，展现了面对挑战时的无畏与执着。\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    seed=0, tiled=True,\n    input_image=input_image,\n    camera_control_direction=\"Up\", camera_control_speed=0.01,\n)\nsave_video(video, \"video_up_Wan2.1-Fun-V1.1-1.3B-Control-Camera.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-1.3B-Control.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-1.3B-Control\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-1.3B-Control\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\", **vram_config),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-1.3B-Control\", origin_file_pattern=\"Wan2.1_VAE.pth\", **vram_config),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-1.3B-Control\", origin_file_pattern=\"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 2,\n)\n\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/examples_in_diffsynth\",\n    local_dir=\"./\",\n    allow_file_pattern=[\"data/examples/wan/control_video.mp4\", \"data/examples/wan/reference_image_girl.png\"]\n)\n\n# Control video\ncontrol_video = VideoData(\"data/examples/wan/control_video.mp4\", height=832, width=576)\nreference_image = Image.open(\"data/examples/wan/reference_image_girl.png\").resize((576, 832))\nvideo = pipe(\n    prompt=\"扁平风格动漫，一位长发少女优雅起舞。她五官精致，大眼睛明亮有神，黑色长发柔顺光泽。身穿淡蓝色T恤和深蓝色牛仔短裤。背景是粉色。\",\n    
negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    control_video=control_video, reference_image=reference_image,\n    height=832, width=576, num_frames=49,\n    seed=1, tiled=True\n)\nsave_video(video, \"video_Wan2.1-Fun-V1.1-1.3B-Control.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-1.3B-InP.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-1.3B-InP\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-1.3B-InP\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\", **vram_config),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-1.3B-InP\", origin_file_pattern=\"Wan2.1_VAE.pth\", **vram_config),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-1.3B-InP\", origin_file_pattern=\"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 2,\n)\n\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/examples_in_diffsynth\",\n    local_dir=\"./\",\n    allow_file_pattern=f\"data/examples/wan/input_image.jpg\"\n)\nimage = Image.open(\"data/examples/wan/input_image.jpg\")\n\n# Image-to-video from the first frame (see the end_image note below for first-and-last-frame mode)\nvideo = pipe(\n    prompt=\"一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌，白色的浪花拍打着船身，但小船毫不畏惧，坚定地驶向远方。阳光洒在水面上，闪烁着金色的光芒，为这壮丽的场景增添了一抹温暖。镜头拉近，可以看到船上的旗帜迎风飘扬，象征着不屈的精神与冒险的勇气。这段画面充满力量，激励人心，展现了面对挑战时的无畏与执着。\",\n    
negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    input_image=image,\n    seed=0, tiled=True\n    # You can input `end_image=xxx` to control the last frame of the video.\n    # The model will automatically generate the dynamic content between `input_image` and `end_image`.\n)\nsave_video(video, \"video_Wan2.1-Fun-V1.1-1.3B-InP.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-14B-Control-Camera.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-14B-Control-Camera\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-14B-Control-Camera\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\", **vram_config),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-14B-Control-Camera\", origin_file_pattern=\"Wan2.1_VAE.pth\", **vram_config),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-14B-Control-Camera\", origin_file_pattern=\"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 2,\n)\n\n\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/examples_in_diffsynth\",\n    local_dir=\"./\",\n    allow_file_pattern=f\"data/examples/wan/input_image.jpg\"\n)\ninput_image = Image.open(\"data/examples/wan/input_image.jpg\")\n\nvideo = pipe(\n    prompt=\"一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌，白色的浪花拍打着船身，但小船毫不畏惧，坚定地驶向远方。阳光洒在水面上，闪烁着金色的光芒，为这壮丽的场景增添了一抹温暖。镜头拉近，可以看到船上的旗帜迎风飘扬，象征着不屈的精神与冒险的勇气。这段画面充满力量，激励人心，展现了面对挑战时的无畏与执着。\",\n    
negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    seed=0, tiled=True,\n    input_image=input_image,\n    camera_control_direction=\"Left\", camera_control_speed=0.01,\n)\nsave_video(video, \"video_left_Wan2.1-Fun-V1.1-14B-Control-Camera.mp4\", fps=15, quality=5)\n\nvideo = pipe(\n    prompt=\"一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌，白色的浪花拍打着船身，但小船毫不畏惧，坚定地驶向远方。阳光洒在水面上，闪烁着金色的光芒，为这壮丽的场景增添了一抹温暖。镜头拉近，可以看到船上的旗帜迎风飘扬，象征着不屈的精神与冒险的勇气。这段画面充满力量，激励人心，展现了面对挑战时的无畏与执着。\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    seed=0, tiled=True,\n    input_image=input_image,\n    camera_control_direction=\"Up\", camera_control_speed=0.01,\n)\nsave_video(video, \"video_up_Wan2.1-Fun-V1.1-14B-Control-Camera.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-14B-Control.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-14B-Control\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-14B-Control\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\", **vram_config),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-14B-Control\", origin_file_pattern=\"Wan2.1_VAE.pth\", **vram_config),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-14B-Control\", origin_file_pattern=\"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 2,\n)\n\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/examples_in_diffsynth\",\n    local_dir=\"./\",\n    allow_file_pattern=[\"data/examples/wan/control_video.mp4\", \"data/examples/wan/reference_image_girl.png\"]\n)\n\n# Control video\ncontrol_video = VideoData(\"data/examples/wan/control_video.mp4\", height=832, width=576)\nreference_image = Image.open(\"data/examples/wan/reference_image_girl.png\").resize((576, 832))\nvideo = pipe(\n    prompt=\"扁平风格动漫，一位长发少女优雅起舞。她五官精致，大眼睛明亮有神，黑色长发柔顺光泽。身穿淡蓝色T恤和深蓝色牛仔短裤。背景是粉色。\",\n    
negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    control_video=control_video, reference_image=reference_image,\n    height=832, width=576, num_frames=49,\n    seed=1, tiled=True\n)\nsave_video(video, \"video_Wan2.1-Fun-V1.1-14B-Control.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-14B-InP.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-14B-InP\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-14B-InP\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\", **vram_config),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-14B-InP\", origin_file_pattern=\"Wan2.1_VAE.pth\", **vram_config),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-14B-InP\", origin_file_pattern=\"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 2,\n)\n\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/examples_in_diffsynth\",\n    local_dir=\"./\",\n    allow_file_pattern=f\"data/examples/wan/input_image.jpg\"\n)\nimage = Image.open(\"data/examples/wan/input_image.jpg\")\n\n# Image-to-video from the first frame (see the end_image note below for first-and-last-frame mode)\nvideo = pipe(\n    prompt=\"一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌，白色的浪花拍打着船身，但小船毫不畏惧，坚定地驶向远方。阳光洒在水面上，闪烁着金色的光芒，为这壮丽的场景增添了一抹温暖。镜头拉近，可以看到船上的旗帜迎风飘扬，象征着不屈的精神与冒险的勇气。这段画面充满力量，激励人心，展现了面对挑战时的无畏与执着。\",\n    
negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    input_image=image,\n    seed=0, tiled=True\n    # You can input `end_image=xxx` to control the last frame of the video.\n    # The model will automatically generate the dynamic content between `input_image` and `end_image`.\n)\nsave_video(video, \"video_Wan2.1-Fun-V1.1-14B-InP.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_inference_low_vram/Wan2.1-I2V-14B-480P.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-I2V-14B-480P\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-I2V-14B-480P\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\", **vram_config),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-I2V-14B-480P\", origin_file_pattern=\"Wan2.1_VAE.pth\", **vram_config),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-I2V-14B-480P\", origin_file_pattern=\"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 2,\n)\n\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/examples_in_diffsynth\",\n    local_dir=\"./\",\n    allow_file_pattern=f\"data/examples/wan/input_image.jpg\"\n)\nimage = Image.open(\"data/examples/wan/input_image.jpg\")\n\n# Image-to-video\nvideo = pipe(\n    prompt=\"一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌，白色的浪花拍打着船身，但小船毫不畏惧，坚定地驶向远方。阳光洒在水面上，闪烁着金色的光芒，为这壮丽的场景增添了一抹温暖。镜头拉近，可以看到船上的旗帜迎风飘扬，象征着不屈的精神与冒险的勇气。这段画面充满力量，激励人心，展现了面对挑战时的无畏与执着。\",\n    
negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    input_image=image,\n    seed=0, tiled=True\n)\nsave_video(video, \"video_Wan2.1-I2V-14B-480P.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_inference_low_vram/Wan2.1-I2V-14B-720P.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-I2V-14B-720P\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-I2V-14B-720P\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\", **vram_config),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-I2V-14B-720P\", origin_file_pattern=\"Wan2.1_VAE.pth\", **vram_config),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-I2V-14B-720P\", origin_file_pattern=\"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 2,\n)\n\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/examples_in_diffsynth\",\n    local_dir=\"./\",\n    allow_file_pattern=f\"data/examples/wan/input_image.jpg\"\n)\nimage = Image.open(\"data/examples/wan/input_image.jpg\")\n\n# Image-to-video\nvideo = pipe(\n    prompt=\"一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌，白色的浪花拍打着船身，但小船毫不畏惧，坚定地驶向远方。阳光洒在水面上，闪烁着金色的光芒，为这壮丽的场景增添了一抹温暖。镜头拉近，可以看到船上的旗帜迎风飘扬，象征着不屈的精神与冒险的勇气。这段画面充满力量，激励人心，展现了面对挑战时的无畏与执着。\",\n    
negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    input_image=image,\n    seed=0, tiled=True,\n    height=720, width=1280,\n)\nsave_video(video, \"video_Wan2.1-I2V-14B-720P.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_inference_low_vram/Wan2.1-T2V-1.3B.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\n\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\", **vram_config),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"Wan2.1_VAE.pth\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 2,\n)\n\n# Text-to-video\nvideo = pipe(\n    prompt=\"纪实摄影风格画面，一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄，两只耳朵立起，神情专注而欢快。阳光洒在它身上，使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地，偶尔点缀着几朵野花，远处隐约可见蓝天和几片白云。透视感鲜明，捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    seed=0, tiled=True,\n)\nsave_video(video, \"video_1_Wan2.1-T2V-1.3B.mp4\", fps=15, quality=5)\n\n# Video-to-video\nvideo = VideoData(\"video_1_Wan2.1-T2V-1.3B.mp4\", height=480, width=832)\nvideo = pipe(\n    prompt=\"纪实摄影风格画面，一只活泼的小狗戴着黑色墨镜在绿茵茵的草地上迅速奔跑。小狗毛色棕黄，戴着黑色墨镜，两只耳朵立起，神情专注而欢快。阳光洒在它身上，使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地，偶尔点缀着几朵野花，远处隐约可见蓝天和几片白云。透视感鲜明，捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。\",\n    
negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    input_video=video, denoising_strength=0.7,\n    seed=1, tiled=True\n)\nsave_video(video, \"video_2_Wan2.1-T2V-1.3B.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_inference_low_vram/Wan2.1-T2V-14B.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\n\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-14B\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-14B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\", **vram_config),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-14B\", origin_file_pattern=\"Wan2.1_VAE.pth\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 2,\n)\n\n# Text-to-video\nvideo = pipe(\n    prompt=\"一名宇航员身穿太空服，面朝镜头骑着一匹机械马在火星表面驰骋。红色的荒凉地表延伸至远方，点缀着巨大的陨石坑和奇特的岩石结构。机械马的步伐稳健，扬起微弱的尘埃，展现出未来科技与原始探索的完美结合。宇航员手持操控装置，目光坚定，仿佛正在开辟人类的新疆域。背景是深邃的宇宙和蔚蓝的地球，画面既科幻又充满希望，让人不禁畅想未来的星际生活。\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    seed=0, tiled=True,\n)\nsave_video(video, \"video_Wan2.1-T2V-14B.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_inference_low_vram/Wan2.1-VACE-1.3B-Preview.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"iic/VACE-Wan2.1-1.3B-Preview\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"iic/VACE-Wan2.1-1.3B-Preview\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\", **vram_config),\n        ModelConfig(model_id=\"iic/VACE-Wan2.1-1.3B-Preview\", origin_file_pattern=\"Wan2.1_VAE.pth\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 2,\n)\n\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/examples_in_diffsynth\",\n    local_dir=\"./\",\n    allow_file_pattern=[\"data/examples/wan/depth_video.mp4\", \"data/examples/wan/cat_fightning.jpg\"]\n)\n\n# Depth video -> Video\ncontrol_video = VideoData(\"data/examples/wan/depth_video.mp4\", height=480, width=832)\nvideo = pipe(\n    prompt=\"两只可爱的橘猫戴上拳击手套，站在一个拳击台上搏斗。\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    vace_video=control_video,\n    seed=1, tiled=True\n)\nsave_video(video, \"video_1_Wan2.1-VACE-1.3B-Preview.mp4\", fps=15, quality=5)\n\n# Reference image -> 
Video\nvideo = pipe(\n    prompt=\"两只可爱的橘猫戴上拳击手套，站在一个拳击台上搏斗。\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    vace_reference_image=Image.open(\"data/examples/wan/cat_fightning.jpg\").resize((832, 480)),\n    seed=1, tiled=True\n)\nsave_video(video, \"video_2_Wan2.1-VACE-1.3B-Preview.mp4\", fps=15, quality=5)\n\n# Depth video + Reference image -> Video\nvideo = pipe(\n    prompt=\"两只可爱的橘猫戴上拳击手套，站在一个拳击台上搏斗。\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    vace_video=control_video,\n    vace_reference_image=Image.open(\"data/examples/wan/cat_fightning.jpg\").resize((832, 480)),\n    seed=1, tiled=True\n)\nsave_video(video, \"video_3_Wan2.1-VACE-1.3B-Preview.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_inference_low_vram/Wan2.1-VACE-1.3B.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-VACE-1.3B\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-VACE-1.3B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\", **vram_config),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-VACE-1.3B\", origin_file_pattern=\"Wan2.1_VAE.pth\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 2,\n)\n\n\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/examples_in_diffsynth\",\n    local_dir=\"./\",\n    allow_file_pattern=[\"data/examples/wan/depth_video.mp4\", \"data/examples/wan/cat_fightning.jpg\"]\n)\n\n# Depth video -> Video\ncontrol_video = VideoData(\"data/examples/wan/depth_video.mp4\", height=480, width=832)\nvideo = pipe(\n    prompt=\"两只可爱的橘猫戴上拳击手套，站在一个拳击台上搏斗。\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    vace_video=control_video,\n    seed=1, tiled=True\n)\nsave_video(video, \"video_1_Wan2.1-VACE-1.3B.mp4\", fps=15, quality=5)\n\n# Reference image -> Video\nvideo = pipe(\n    
prompt=\"两只可爱的橘猫戴上拳击手套，站在一个拳击台上搏斗。\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    vace_reference_image=Image.open(\"data/examples/wan/cat_fightning.jpg\").resize((832, 480)),\n    seed=1, tiled=True\n)\nsave_video(video, \"video_2_Wan2.1-VACE-1.3B.mp4\", fps=15, quality=5)\n\n# Depth video + Reference image -> Video\nvideo = pipe(\n    prompt=\"两只可爱的橘猫戴上拳击手套，站在一个拳击台上搏斗。\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    vace_video=control_video,\n    vace_reference_image=Image.open(\"data/examples/wan/cat_fightning.jpg\").resize((832, 480)),\n    seed=1, tiled=True\n)\nsave_video(video, \"video_3_Wan2.1-VACE-1.3B.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_inference_low_vram/Wan2.1-VACE-14B.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-VACE-14B\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-VACE-14B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\", **vram_config),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-VACE-14B\", origin_file_pattern=\"Wan2.1_VAE.pth\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 2,\n)\n\n\n\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/examples_in_diffsynth\",\n    local_dir=\"./\",\n    allow_file_pattern=[\"data/examples/wan/depth_video.mp4\", \"data/examples/wan/cat_fightning.jpg\"]\n)\n\n# Depth video -> Video\ncontrol_video = VideoData(\"data/examples/wan/depth_video.mp4\", height=480, width=832)\nvideo = pipe(\n    prompt=\"两只可爱的橘猫戴上拳击手套，站在一个拳击台上搏斗。\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    vace_video=control_video,\n    seed=1, tiled=True\n)\nsave_video(video, \"video_1_Wan2.1-VACE-14B.mp4\", fps=15, quality=5)\n\n# Reference image -> Video\nvideo = pipe(\n    
prompt=\"两只可爱的橘猫戴上拳击手套，站在一个拳击台上搏斗。\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    vace_reference_image=Image.open(\"data/examples/wan/cat_fightning.jpg\").resize((832, 480)),\n    seed=1, tiled=True\n)\nsave_video(video, \"video_2_Wan2.1-VACE-14B.mp4\", fps=15, quality=5)\n\n# Depth video + Reference image -> Video\nvideo = pipe(\n    prompt=\"两只可爱的橘猫戴上拳击手套，站在一个拳击台上搏斗。\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    vace_video=control_video,\n    vace_reference_image=Image.open(\"data/examples/wan/cat_fightning.jpg\").resize((832, 480)),\n    seed=1, tiled=True\n)\nsave_video(video, \"video_3_Wan2.1-VACE-14B.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_inference_low_vram/Wan2.2-Animate-14B.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.core import load_state_dict\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download, snapshot_download\n\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-Animate-14B\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-Animate-14B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\", **vram_config),\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-Animate-14B\", origin_file_pattern=\"Wan2.1_VAE.pth\", **vram_config),\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-Animate-14B\", origin_file_pattern=\"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 2,\n)\n\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/examples_in_diffsynth\",\n    local_dir=\"./\",\n    allow_file_pattern=\"data/examples/wan/animate/*\",\n)\n\n# Animate\ninput_image = Image.open(\"data/examples/wan/animate/animate_input_image.png\")\nanimate_pose_video = VideoData(\"data/examples/wan/animate/animate_pose_video.mp4\").raw_data()[:81-4]\nanimate_face_video = VideoData(\"data/examples/wan/animate/animate_face_video.mp4\").raw_data()[:81-4]\nvideo = pipe(\n    
prompt=\"视频中的人在做动作\",\n    seed=0, tiled=True,\n    input_image=input_image,\n    animate_pose_video=animate_pose_video,\n    animate_face_video=animate_face_video,\n    num_frames=81, height=720, width=1280,\n    num_inference_steps=20, cfg_scale=1,\n)\nsave_video(video, \"video_1_Wan2.2-Animate-14B.mp4\", fps=15, quality=5)\n\n# Replace\nsnapshot_download(\"Wan-AI/Wan2.2-Animate-14B\", allow_file_pattern=\"relighting_lora.ckpt\", local_dir=\"models/Wan-AI/Wan2.2-Animate-14B\")\nlora_state_dict = load_state_dict(\"models/Wan-AI/Wan2.2-Animate-14B/relighting_lora.ckpt\", torch_dtype=torch.bfloat16, device=\"cuda\")[\"state_dict\"]\nlora_state_dict = {i: lora_state_dict[i].to(torch.bfloat16) for i in lora_state_dict}\npipe.load_lora(pipe.dit, state_dict=lora_state_dict)\ninput_image = Image.open(\"data/examples/wan/animate/replace_input_image.png\")\nanimate_pose_video = VideoData(\"data/examples/wan/animate/replace_pose_video.mp4\").raw_data()[:81-4]\nanimate_face_video = VideoData(\"data/examples/wan/animate/replace_face_video.mp4\").raw_data()[:81-4]\nanimate_inpaint_video = VideoData(\"data/examples/wan/animate/replace_inpaint_video.mp4\").raw_data()[:81-4]\nanimate_mask_video = VideoData(\"data/examples/wan/animate/replace_mask_video.mp4\").raw_data()[:81-4]\nvideo = pipe(\n    prompt=\"视频中的人在做动作\",\n    seed=0, tiled=True,\n    input_image=input_image,\n    animate_pose_video=animate_pose_video,\n    animate_face_video=animate_face_video,\n    animate_inpaint_video=animate_inpaint_video,\n    animate_mask_video=animate_mask_video,\n    num_frames=81, height=720, width=1280,\n    num_inference_steps=20, cfg_scale=1,\n)\nsave_video(video, \"video_2_Wan2.2-Animate-14B.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_inference_low_vram/Wan2.2-Fun-A14B-Control-Camera.py",
    "content": "import torch\nfrom diffsynth.utils.data import save_video,VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom PIL import Image\nfrom modelscope import dataset_snapshot_download\n\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"PAI/Wan2.2-Fun-A14B-Control-Camera\", origin_file_pattern=\"high_noise_model/diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"PAI/Wan2.2-Fun-A14B-Control-Camera\", origin_file_pattern=\"low_noise_model/diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"PAI/Wan2.2-Fun-A14B-Control-Camera\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\", **vram_config),\n        ModelConfig(model_id=\"PAI/Wan2.2-Fun-A14B-Control-Camera\", origin_file_pattern=\"Wan2.1_VAE.pth\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 2,\n)\n\n\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/examples_in_diffsynth\",\n    local_dir=\"./\",\n    allow_file_pattern=f\"data/examples/wan/input_image.jpg\"\n)\ninput_image = Image.open(\"data/examples/wan/input_image.jpg\")\n\nvideo = pipe(\n    prompt=\"一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌，白色的浪花拍打着船身，但小船毫不畏惧，坚定地驶向远方。阳光洒在水面上，闪烁着金色的光芒，为这壮丽的场景增添了一抹温暖。镜头拉近，可以看到船上的旗帜迎风飘扬，象征着不屈的精神与冒险的勇气。这段画面充满力量，激励人心，展现了面对挑战时的无畏与执着。\",\n    
negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    seed=0, tiled=True,\n    input_image=input_image,\n    camera_control_direction=\"Left\", camera_control_speed=0.01,\n)\nsave_video(video, \"video_left_Wan2.2-Fun-A14B-Control-Camera.mp4\", fps=15, quality=5)\n\nvideo = pipe(\n    prompt=\"一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌，白色的浪花拍打着船身，但小船毫不畏惧，坚定地驶向远方。阳光洒在水面上，闪烁着金色的光芒，为这壮丽的场景增添了一抹温暖。镜头拉近，可以看到船上的旗帜迎风飘扬，象征着不屈的精神与冒险的勇气。这段画面充满力量，激励人心，展现了面对挑战时的无畏与执着。\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    seed=0, tiled=True,\n    input_image=input_image,\n    camera_control_direction=\"Up\", camera_control_speed=0.01,\n)\nsave_video(video, \"video_up_Wan2.2-Fun-A14B-Control-Camera.mp4\", fps=15, quality=5)"
  },
  {
    "path": "examples/wanvideo/model_inference_low_vram/Wan2.2-Fun-A14B-Control.py",
    "content": "import torch\nfrom diffsynth.utils.data import save_video,VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom PIL import Image\nfrom modelscope import dataset_snapshot_download\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"PAI/Wan2.2-Fun-A14B-Control\", origin_file_pattern=\"high_noise_model/diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"PAI/Wan2.2-Fun-A14B-Control\", origin_file_pattern=\"low_noise_model/diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"PAI/Wan2.2-Fun-A14B-Control\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\", **vram_config),\n        ModelConfig(model_id=\"PAI/Wan2.2-Fun-A14B-Control\", origin_file_pattern=\"Wan2.1_VAE.pth\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 2,\n)\n\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/examples_in_diffsynth\",\n    local_dir=\"./\",\n    allow_file_pattern=[\"data/examples/wan/control_video.mp4\", \"data/examples/wan/reference_image_girl.png\"]\n)\n\n# Control video\ncontrol_video = VideoData(\"data/examples/wan/control_video.mp4\", height=832, width=576)\nreference_image = Image.open(\"data/examples/wan/reference_image_girl.png\").resize((576, 832))\nvideo = pipe(\n    prompt=\"扁平风格动漫，一位长发少女优雅起舞。她五官精致，大眼睛明亮有神，黑色长发柔顺光泽。身穿淡蓝色T恤和深蓝色牛仔短裤。背景是粉色。\",\n    
negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    control_video=control_video, reference_image=reference_image,\n    height=832, width=576, num_frames=49,\n    seed=1, tiled=True\n)\nsave_video(video, \"video_Wan2.2-Fun-A14B-Control.mp4\", fps=15, quality=5)"
  },
  {
    "path": "examples/wanvideo/model_inference_low_vram/Wan2.2-Fun-A14B-InP.py",
    "content": "import torch\nfrom diffsynth.utils.data import save_video\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom PIL import Image\nfrom modelscope import dataset_snapshot_download\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"PAI/Wan2.2-Fun-A14B-InP\", origin_file_pattern=\"high_noise_model/diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"PAI/Wan2.2-Fun-A14B-InP\", origin_file_pattern=\"low_noise_model/diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"PAI/Wan2.2-Fun-A14B-InP\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\", **vram_config),\n        ModelConfig(model_id=\"PAI/Wan2.2-Fun-A14B-InP\", origin_file_pattern=\"Wan2.1_VAE.pth\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 2,\n)\n\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/examples_in_diffsynth\",\n    local_dir=\"./\",\n    allow_file_pattern=f\"data/examples/wan/input_image.jpg\"\n)\nimage = Image.open(\"data/examples/wan/input_image.jpg\")\n\n# First and last frame to video\nvideo = pipe(\n    prompt=\"一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌，白色的浪花拍打着船身，但小船毫不畏惧，坚定地驶向远方。阳光洒在水面上，闪烁着金色的光芒，为这壮丽的场景增添了一抹温暖。镜头拉近，可以看到船上的旗帜迎风飘扬，象征着不屈的精神与冒险的勇气。这段画面充满力量，激励人心，展现了面对挑战时的无畏与执着。\",\n    
negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    input_image=image,\n    seed=0, tiled=True,\n    # You can input `end_image=xxx` to control the last frame of the video.\n    # The model will automatically generate the dynamic content between `input_image` and `end_image`.\n)\nsave_video(video, \"video_Wan2.2-Fun-A14B-InP.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_inference_low_vram/Wan2.2-I2V-A14B.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-I2V-A14B\", origin_file_pattern=\"high_noise_model/diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-I2V-A14B\", origin_file_pattern=\"low_noise_model/diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-I2V-A14B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\", **vram_config),\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-I2V-A14B\", origin_file_pattern=\"Wan2.1_VAE.pth\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 2,\n)\n\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/examples_in_diffsynth\",\n    local_dir=\"./\",\n    allow_file_pattern=[\"data/examples/wan/cat_fightning.jpg\"]\n)\ninput_image = Image.open(\"data/examples/wan/cat_fightning.jpg\").resize((832, 480))\n\nvideo = pipe(\n    prompt=\"Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    seed=0, 
tiled=True,\n    input_image=input_image,\n    switch_DiT_boundary=0.9,\n)\nsave_video(video, \"video_Wan2.2-I2V-A14B.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_inference_low_vram/Wan2.2-S2V-14B.py",
    "content": "# This script can generate a single video clip.\n# If you need generate long videos, please refer to `Wan2.2-S2V-14B_multi_clips.py`.\nimport torch\nfrom PIL import Image\nimport librosa\nfrom diffsynth.utils.data import VideoData, save_video_with_audio\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-S2V-14B\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-S2V-14B\", origin_file_pattern=\"wav2vec2-large-xlsr-53-english/model.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-S2V-14B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\", **vram_config),\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-S2V-14B\", origin_file_pattern=\"Wan2.1_VAE.pth\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n    audio_processor_config=ModelConfig(model_id=\"Wan-AI/Wan2.2-S2V-14B\", origin_file_pattern=\"wav2vec2-large-xlsr-53-english/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 2,\n)\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/example_video_dataset\",\n    local_dir=\"./data/example_video_dataset\",\n    allow_file_pattern=f\"wans2v/*\"\n)\n\nnum_frames = 81 # 4n+1\nheight = 448\nwidth = 832\n\nprompt = \"a person is singing\"\nnegative_prompt = 
\"画面模糊，最差质量，画面模糊，细节模糊不清，情绪激动剧烈，手快速抖动，字幕，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\"\ninput_image = Image.open(\"data/example_video_dataset/wans2v/pose.png\").convert(\"RGB\").resize((width, height))\n# s2v audio input, recommend 16kHz sampling rate\naudio_path = 'data/example_video_dataset/wans2v/sing.MP3'\ninput_audio, sample_rate = librosa.load(audio_path, sr=16000)\n\n# Speech-to-video\nvideo = pipe(\n    prompt=prompt,\n    input_image=input_image,\n    negative_prompt=negative_prompt,\n    seed=0,\n    num_frames=num_frames,\n    height=height,\n    width=width,\n    audio_sample_rate=sample_rate,\n    input_audio=input_audio,\n    num_inference_steps=40,\n)\nsave_video_with_audio(video[1:], \"video_1_Wan2.2-S2V-14B.mp4\", audio_path, fps=16, quality=5)\n\n# s2v will use the first (num_frames) frames as reference. height and width must be the same as input_image. And fps should be 16, the same as output video fps.\npose_video_path = 'data/example_video_dataset/wans2v/pose.mp4'\npose_video = VideoData(pose_video_path, height=height, width=width)\n\n# Speech-to-video with pose\nvideo = pipe(\n    prompt=prompt,\n    input_image=input_image,\n    negative_prompt=negative_prompt,\n    seed=0,\n    num_frames=num_frames,\n    height=height,\n    width=width,\n    audio_sample_rate=sample_rate,\n    input_audio=input_audio,\n    s2v_pose_video=pose_video,\n    num_inference_steps=40,\n)\nsave_video_with_audio(video[1:], \"video_2_Wan2.2-S2V-14B.mp4\", audio_path, fps=16, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_inference_low_vram/Wan2.2-S2V-14B_multi_clips.py",
    "content": "import torch\nfrom PIL import Image\nimport librosa\nfrom diffsynth.utils.data import VideoData, save_video_with_audio\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig, WanVideoUnit_S2V\nfrom modelscope import dataset_snapshot_download\n\n\ndef speech_to_video(\n    prompt,\n    input_image,\n    audio_path,\n    negative_prompt=\"\",\n    num_clip=None,\n    audio_sample_rate=16000,\n    pose_video_path=None,\n    infer_frames=80,\n    height=448,\n    width=832,\n    num_inference_steps=40,\n    fps=16, # recommend fixing fps as 16 for s2v\n    motion_frames=73, # hyperparameter of wan2.2-s2v\n    save_path=None,\n):\n    # s2v audio input, recommend 16kHz sampling rate\n    input_audio, sample_rate = librosa.load(audio_path, sr=audio_sample_rate)\n    # s2v will use the first (num_frames) frames as reference. height and width must be the same as input_image. And fps should be 16, the same as output video fps.\n    pose_video = VideoData(pose_video_path, height=height, width=width) if pose_video_path is not None else None\n\n    with torch.no_grad():\n        audio_embeds, pose_latents, num_repeat = WanVideoUnit_S2V.pre_calculate_audio_pose(\n            pipe=pipe,\n            input_audio=input_audio,\n            audio_sample_rate=sample_rate,\n            s2v_pose_video=pose_video,\n            num_frames=infer_frames + 1,\n            height=height,\n            width=width,\n            fps=fps,\n        )\n    num_repeat = min(num_repeat, num_clip) if num_clip is not None else num_repeat\n    print(f\"Generating {num_repeat} video clips...\")\n    motion_video = None\n    video = []\n    for r in range(num_repeat):\n        s2v_pose_latents = pose_latents[r] if pose_latents is not None else None\n        current_clip_tensor = pipe(\n            prompt=prompt,\n            input_image=input_image,\n            negative_prompt=negative_prompt,\n            seed=0,\n            num_frames=infer_frames + 1,\n            
height=height,\n            width=width,\n            audio_embeds=audio_embeds[r],\n            s2v_pose_latents=s2v_pose_latents,\n            motion_video=motion_video,\n            num_inference_steps=num_inference_steps,\n            output_type=\"floatpoint\",\n        )\n        current_clip_tensor = current_clip_tensor[:,:,-infer_frames:,:,:]\n        if r == 0:\n            current_clip_tensor = current_clip_tensor[:,:,3:,:,:]\n            overlap_frames_num = min(motion_frames, current_clip_tensor.shape[2])\n            motion_video = current_clip_tensor[:,:,-overlap_frames_num:,:,:].clone()\n        else:\n            overlap_frames_num = min(motion_frames, current_clip_tensor.shape[2])\n            motion_video = torch.cat((motion_video[:,:,overlap_frames_num:,:,:], current_clip_tensor[:,:,-overlap_frames_num:,:,:]), dim=2)\n        current_clip_quantized = pipe.vae_output_to_video(current_clip_tensor)\n        video.extend(current_clip_quantized)\n        save_video_with_audio(video, save_path, audio_path, fps=16, quality=5)\n        print(f\"processed the {r+1}th clip of total {num_repeat} clips.\")\n    return video\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-S2V-14B\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-S2V-14B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\", **vram_config),\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-S2V-14B\", origin_file_pattern=\"wav2vec2-large-xlsr-53-english/model.safetensors\", 
**vram_config),\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-S2V-14B\", origin_file_pattern=\"Wan2.1_VAE.pth\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n    audio_processor_config=ModelConfig(model_id=\"Wan-AI/Wan2.2-S2V-14B\", origin_file_pattern=\"wav2vec2-large-xlsr-53-english/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 2,\n)\n\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/example_video_dataset\",\n    local_dir=\"./data/example_video_dataset\",\n    allow_file_pattern=f\"wans2v/*\",\n)\n\ninfer_frames = 80  # 4n\nheight = 448\nwidth = 832\n\nprompt = \"a person is singing\"\nnegative_prompt = \"画面模糊，最差质量，画面模糊，细节模糊不清，情绪激动剧烈，手快速抖动，字幕，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\"\ninput_image = Image.open(\"data/example_video_dataset/wans2v/pose.png\").convert(\"RGB\").resize((width, height))\n\nvideo_with_audio = speech_to_video(\n    prompt=prompt,\n    input_image=input_image,\n    audio_path='data/example_video_dataset/wans2v/sing.MP3',\n    negative_prompt=negative_prompt,\n    pose_video_path='data/example_video_dataset/wans2v/pose.mp4',\n    save_path=\"video_full_Wan2.2-S2V-14B.mp4\",\n    infer_frames=infer_frames,\n    height=height,\n    width=width,\n)\n# num_clip means generating only the first n clips with n * infer_frames frames.\nvideo_with_audio_pose = speech_to_video(\n    prompt=prompt,\n    input_image=input_image,\n    audio_path='data/example_video_dataset/wans2v/sing.MP3',\n    negative_prompt=negative_prompt,\n    pose_video_path='data/example_video_dataset/wans2v/pose.mp4',\n    save_path=\"video_clip_2_Wan2.2-S2V-14B.mp4\",\n    num_clip=2\n)\n"
  },
  {
    "path": "examples/wanvideo/model_inference_low_vram/Wan2.2-T2V-A14B.py",
    "content": "import torch\nfrom diffsynth.utils.data import save_video\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\n\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-T2V-A14B\", origin_file_pattern=\"high_noise_model/diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-T2V-A14B\", origin_file_pattern=\"low_noise_model/diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-T2V-A14B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\", **vram_config),\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-T2V-A14B\", origin_file_pattern=\"Wan2.1_VAE.pth\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 2,\n)\n\n# Text-to-video\nvideo = pipe(\n    prompt=\"Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    seed=0, tiled=True,\n)\nsave_video(video, \"video_Wan2.2-T2V-A14B.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_inference_low_vram/Wan2.2-TI2V-5B.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-TI2V-5B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\", **vram_config),\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-TI2V-5B\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-TI2V-5B\", origin_file_pattern=\"Wan2.2_VAE.pth\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 2,\n)\n\n# Text-to-video\nvideo = pipe(\n    prompt=\"两只可爱的橘猫戴上拳击手套，站在一个拳击台上搏斗。\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    seed=0, tiled=True,\n    height=704, width=1248,\n    num_frames=121,\n)\nsave_video(video, \"video_1_Wan2.2-TI2V-5B.mp4\", fps=15, quality=5)\n\n# Image-to-video\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/examples_in_diffsynth\",\n    local_dir=\"./\",\n    allow_file_pattern=[\"data/examples/wan/cat_fightning.jpg\"]\n)\ninput_image = Image.open(\"data/examples/wan/cat_fightning.jpg\").resize((1248, 704))\nvideo = pipe(\n    prompt=\"两只可爱的橘猫戴上拳击手套，站在一个拳击台上搏斗。\",\n    
negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    seed=0, tiled=True,\n    height=704, width=1248,\n    input_image=input_image,\n    num_frames=121,\n)\nsave_video(video, \"video_2_Wan2.2-TI2V-5B.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_inference_low_vram/Wan2.2-VACE-Fun-A14B.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"PAI/Wan2.2-VACE-Fun-A14B\", origin_file_pattern=\"high_noise_model/diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"PAI/Wan2.2-VACE-Fun-A14B\", origin_file_pattern=\"low_noise_model/diffusion_pytorch_model*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"PAI/Wan2.2-VACE-Fun-A14B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\", **vram_config),\n        ModelConfig(model_id=\"PAI/Wan2.2-VACE-Fun-A14B\", origin_file_pattern=\"Wan2.1_VAE.pth\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 2,\n)\n\n\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/examples_in_diffsynth\",\n    local_dir=\"./\",\n    allow_file_pattern=[\"data/examples/wan/depth_video.mp4\", \"data/examples/wan/cat_fightning.jpg\"]\n)\n\n# Depth video -> Video\ncontrol_video = VideoData(\"data/examples/wan/depth_video.mp4\", height=480, width=832)\nvideo = pipe(\n    prompt=\"两只可爱的橘猫戴上拳击手套，站在一个拳击台上搏斗。\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    
vace_video=control_video,\n    seed=1, tiled=True\n)\nsave_video(video, \"video_1_Wan2.2-VACE-Fun-A14B.mp4\", fps=15, quality=5)\n\n# Reference image -> Video\nvideo = pipe(\n    prompt=\"两只可爱的橘猫戴上拳击手套，站在一个拳击台上搏斗。\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    vace_reference_image=Image.open(\"data/examples/wan/cat_fightning.jpg\").resize((832, 480)),\n    seed=1, tiled=True\n)\nsave_video(video, \"video_2_Wan2.2-VACE-Fun-A14B.mp4\", fps=15, quality=5)\n\n# Depth video + Reference image -> Video\nvideo = pipe(\n    prompt=\"两只可爱的橘猫戴上拳击手套，站在一个拳击台上搏斗。\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    vace_video=control_video,\n    vace_reference_image=Image.open(\"data/examples/wan/cat_fightning.jpg\").resize((832, 480)),\n    seed=1, tiled=True\n)\nsave_video(video, \"video_3_Wan2.2-VACE-Fun-A14B.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_inference_low_vram/WanToDance-14B-global.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/WanToDance-14B\", origin_file_pattern=\"global_model.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Wan-AI/WanToDance-14B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\", **vram_config),\n        ModelConfig(model_id=\"Wan-AI/WanToDance-14B\", origin_file_pattern=\"Wan2.1_VAE.pth\", **vram_config),\n        ModelConfig(model_id=\"Wan-AI/WanToDance-14B\", origin_file_pattern=\"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 2,\n)\ndataset_snapshot_download(\n    \"DiffSynth-Studio/diffsynth_example_dataset\",\n    local_dir=\"data/diffsynth_example_dataset\",\n    allow_file_pattern=\"wanvideo/WanToDance-14B-global/*\"\n)\n# This is a specialized model with the following constraints on its input parameters:\n# *   The model outputs a sequence of keyframes rather than a video; therefore, `framewise_decoding=True` must be set.\n# *   When the number of keyframes is $n$, `num_frames` = 4 * (n - 1) + 1.\n# *   Reducing `height`, `width`, `num_frames`, or `num_inference_steps` may lead to severe artifacts or generation 
failure.\n# *   The audio file specified by `wantodance_music_path` must match the video duration, calculated as (`num_frames` / 7.5) seconds.\n# *   The width and height of `wantodance_reference_image` must be multiples of 16.\n# *   `wantodance_fps` is configurable, but since the model appears to have been trained exclusively at 7.5 FPS, setting it to other values is not recommended.\n# *   The first frame of `wantodance_keyframes` is the `wantodance_reference_image`, while all subsequent frames are solid black.\n# *   `wantodance_keyframes_mask` indicates the positions of valid frames within `wantodance_keyframes`.\nwantodance_keyframes = VideoData(\"data/diffsynth_example_dataset/wanvideo/WanToDance-14B-global/keyframes.mp4\")\nwantodance_keyframes = [wantodance_keyframes[i] for i in range(149)]\nvideo = pipe(\n    prompt=\"一个人正在跳舞，舞蹈种类是韩舞。帧率是7.5000\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    seed=0, tiled=False,\n    height=1280, width=720, num_frames=149,\n    num_inference_steps=48,\n    wantodance_music_path=\"data/diffsynth_example_dataset/wanvideo/WanToDance-14B-global/music.WAV\",\n    wantodance_reference_image=Image.open(\"data/diffsynth_example_dataset/wanvideo/WanToDance-14B-global/refimage.jpg\"),\n    wantodance_fps=7.5,\n    wantodance_keyframes=wantodance_keyframes,\n    wantodance_keyframes_mask=[1] + [0] * 148,\n    framewise_decoding=True,\n)\nsave_video(video, \"video_WanToDance-14B-global.mp4\", fps=7.5, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_inference_low_vram/WanToDance-14B-local.py",
    "content": "import torch, os\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/WanToDance-14B\", origin_file_pattern=\"local_model.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Wan-AI/WanToDance-14B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\", **vram_config),\n        ModelConfig(model_id=\"Wan-AI/WanToDance-14B\", origin_file_pattern=\"Wan2.1_VAE.pth\", **vram_config),\n        ModelConfig(model_id=\"Wan-AI/WanToDance-14B\", origin_file_pattern=\"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 2,\n)\ndataset_snapshot_download(\n    \"DiffSynth-Studio/diffsynth_example_dataset\",\n    local_dir=\"data/diffsynth_example_dataset\",\n    allow_file_pattern=\"wanvideo/WanToDance-14B-local/*\"\n)\n# This is a specialized model with the following constraints on its input parameters:\n# *   The model renders and outputs video based on a sequence of keyframes; therefore, `wantodance_keyframes` must be provided correctly.\n# *   If you need to generate a long video, please generate it in segments, and ensure that `wantodance_music_path`, `wantodance_keyframes`, and `wantodance_keyframes_mask` are 
properly split accordingly.\n# *   The audio file specified by `wantodance_music_path` must match the video duration, calculated as (`num_frames` / 30) seconds.\n# *   The width and height of `wantodance_reference_image` must be multiples of 16.\n# *   `wantodance_fps` is configurable, but since the model appears to have been trained exclusively at 30 FPS, setting it to other values is not recommended.\n# *   In `wantodance_keyframes`, frames that are not keyframes should be solid black.\n# *   `wantodance_keyframes_mask` indicates the positions of valid frames within `wantodance_keyframes`.\nwantodance_keyframes = VideoData(\"data/diffsynth_example_dataset/wanvideo/WanToDance-14B-local/keyframes.mp4\")\nwantodance_keyframes = [wantodance_keyframes[i] for i in range(149)]\nvideo = pipe(\n    prompt=\"一个人正在跳舞，舞蹈种类是古典舞,图像清晰程度高,人物动作平均幅度中等,人物动作最大幅度中等。, 帧率是30fps。\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    seed=0, tiled=True,\n    height=1280, width=720, num_frames=149,\n    num_inference_steps=24,\n    wantodance_music_path=\"data/diffsynth_example_dataset/wanvideo/WanToDance-14B-local/music.wav\",\n    wantodance_reference_image=Image.open(\"data/diffsynth_example_dataset/wanvideo/WanToDance-14B-local/refimage.jpg\"),\n    wantodance_fps=30,\n    wantodance_keyframes=wantodance_keyframes,\n    wantodance_keyframes_mask=[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n                               1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n                               1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n                               1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n                               1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n                               1, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n                               1],\n)\nsave_video(video, \"video_WanToDance-14B-local.mp4\", fps=30, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_inference_low_vram/krea-realtime-video.py",
    "content": "import torch\nfrom diffsynth.utils.data import save_video\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\n\n\nvram_config = {\n    \"offload_dtype\": \"disk\",\n    \"offload_device\": \"disk\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"krea/krea-realtime-video\", origin_file_pattern=\"krea-realtime-video-14b.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-14B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\", **vram_config),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-14B\", origin_file_pattern=\"Wan2.1_VAE.pth\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 2,\n)\n\n# Text-to-video\nvideo = pipe(\n    prompt=\"a cat sitting on a boat\",\n    num_inference_steps=6, num_frames=81,\n    seed=0, tiled=True,\n    cfg_scale=1,\n    sigma_shift=20,\n)\nsave_video(video, \"video_krea-realtime-video.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_training/full/LongCat-Video.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"wanvideo/LongCat-Video/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/LongCat-Video \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/LongCat-Video/metadata.csv \\\n  --height 480 \\\n  --width 832 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"meituan-longcat/LongCat-Video:dit/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-T2V-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-T2V-14B:Wan2.1_VAE.pth\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 2 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/LongCat-Video_full\" \\\n  --trainable_models \"dit\"\n"
  },
  {
    "path": "examples/wanvideo/model_training/full/Video-As-Prompt-Wan2.1-14B.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"wanvideo/Video-As-Prompt-Wan2.1-14B/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Video-As-Prompt-Wan2.1-14B \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Video-As-Prompt-Wan2.1-14B/metadata.csv \\\n  --data_file_keys \"video,vap_video\" \\\n  --height 480 \\\n  --width 832 \\\n  --num_frames 49 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"ByteDance/Video-As-Prompt-Wan2.1-14B:transformer/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-I2V-14B-720P:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-I2V-14B-720P:Wan2.1_VAE.pth,Wan-AI/Wan2.1-I2V-14B-720P:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 2 \\\n  --remove_prefix_in_ckpt \"pipe.vap.\" \\\n  --output_path \"./models/train/Video-As-Prompt-Wan2.1-14B_full\" \\\n  --trainable_models \"vap\" \\\n  --extra_inputs \"vap_video,input_image\" \\\n  --use_gradient_checkpointing_offload\n"
  },
  {
    "path": "examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"wanvideo/Wan2.1-1.3b-speedcontrol-v1/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.1-1.3b-speedcontrol-v1 \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.1-1.3b-speedcontrol-v1/metadata.csv \\\n  --height 480 \\\n  --width 832 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-T2V-1.3B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-T2V-1.3B:Wan2.1_VAE.pth,DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1:model.safetensors\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 2 \\\n  --remove_prefix_in_ckpt \"pipe.motion_controller.\" \\\n  --output_path \"./models/train/Wan2.1-1.3b-speedcontrol-v1_full\" \\\n  --trainable_models \"motion_controller\" \\\n  --extra_inputs \"motion_bucket_id\"\n"
  },
  {
    "path": "examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"wanvideo/Wan2.1-FLF2V-14B-720P/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.1-FLF2V-14B-720P \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.1-FLF2V-14B-720P/metadata.csv \\\n  --height 480 \\\n  --width 832 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"Wan-AI/Wan2.1-FLF2V-14B-720P:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-FLF2V-14B-720P:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-FLF2V-14B-720P:Wan2.1_VAE.pth,Wan-AI/Wan2.1-FLF2V-14B-720P:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 2 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Wan2.1-FLF2V-14B-720P_full\" \\\n  --trainable_models \"dit\" \\\n  --extra_inputs \"input_image,end_image\" \\\n  --initialize_model_on_cpu\n"
  },
  {
    "path": "examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-Control.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"wanvideo/Wan2.1-Fun-1.3B-Control/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.1-Fun-1.3B-Control \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.1-Fun-1.3B-Control/metadata.csv \\\n  --data_file_keys \"video,control_video\" \\\n  --height 480 \\\n  --width 832 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"PAI/Wan2.1-Fun-1.3B-Control:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-1.3B-Control:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-1.3B-Control:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-1.3B-Control:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 2 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Wan2.1-Fun-1.3B-Control_full\" \\\n  --trainable_models \"dit\" \\\n  --extra_inputs \"control_video\"\n"
  },
  {
    "path": "examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-InP.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"wanvideo/Wan2.1-Fun-1.3B-InP/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.1-Fun-1.3B-InP \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.1-Fun-1.3B-InP/metadata.csv \\\n  --height 480 \\\n  --width 832 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"PAI/Wan2.1-Fun-1.3B-InP:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-1.3B-InP:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-1.3B-InP:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-1.3B-InP:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 2 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Wan2.1-Fun-1.3B-InP_full\" \\\n  --trainable_models \"dit\" \\\n  --extra_inputs \"input_image,end_image\"\n"
  },
  {
    "path": "examples/wanvideo/model_training/full/Wan2.1-Fun-14B-Control.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"wanvideo/Wan2.1-Fun-14B-Control/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.1-Fun-14B-Control \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.1-Fun-14B-Control/metadata.csv \\\n  --data_file_keys \"video,control_video\" \\\n  --height 480 \\\n  --width 832 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"PAI/Wan2.1-Fun-14B-Control:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-14B-Control:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-14B-Control:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-14B-Control:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 2 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Wan2.1-Fun-14B-Control_full\" \\\n  --trainable_models \"dit\" \\\n  --extra_inputs \"control_video\"\n"
  },
  {
    "path": "examples/wanvideo/model_training/full/Wan2.1-Fun-14B-InP.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"wanvideo/Wan2.1-Fun-14B-InP/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.1-Fun-14B-InP \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.1-Fun-14B-InP/metadata.csv \\\n  --height 480 \\\n  --width 832 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"PAI/Wan2.1-Fun-14B-InP:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-14B-InP:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-14B-InP:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-14B-InP:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 2 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Wan2.1-Fun-14B-InP_full\" \\\n  --trainable_models \"dit\" \\\n  --extra_inputs \"input_image,end_image\"\n"
  },
  {
    "path": "examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"wanvideo/Wan2.1-Fun-V1.1-1.3B-Control-Camera/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.1-Fun-V1.1-1.3B-Control-Camera \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.1-Fun-V1.1-1.3B-Control-Camera/metadata.csv \\\n  --height 480 \\\n  --width 832 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 2 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Wan2.1-Fun-V1.1-1.3B-Control-Camera_full\" \\\n  --trainable_models \"dit\" \\\n  --extra_inputs \"input_image,camera_control_direction,camera_control_speed\"\n"
  },
  {
    "path": "examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"wanvideo/Wan2.1-Fun-V1.1-1.3B-Control/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.1-Fun-V1.1-1.3B-Control \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.1-Fun-V1.1-1.3B-Control/metadata.csv \\\n  --data_file_keys \"video,control_video,reference_image\" \\\n  --height 480 \\\n  --width 832 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"PAI/Wan2.1-Fun-V1.1-1.3B-Control:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-V1.1-1.3B-Control:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-V1.1-1.3B-Control:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-V1.1-1.3B-Control:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 2 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Wan2.1-Fun-V1.1-1.3B-Control_full\" \\\n  --trainable_models \"dit\" \\\n  --extra_inputs \"control_video,reference_image\"\n"
  },
  {
    "path": "examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"wanvideo/Wan2.1-Fun-V1.1-1.3B-InP/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.1-Fun-V1.1-1.3B-InP \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.1-Fun-V1.1-1.3B-InP/metadata.csv \\\n  --height 480 \\\n  --width 832 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"PAI/Wan2.1-Fun-V1.1-1.3B-InP:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-V1.1-1.3B-InP:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-V1.1-1.3B-InP:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-V1.1-1.3B-InP:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 2 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Wan2.1-Fun-V1.1-1.3B-InP_full\" \\\n  --trainable_models \"dit\" \\\n  --extra_inputs \"input_image,end_image\"\n"
  },
  {
    "path": "examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control-Camera.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"wanvideo/Wan2.1-Fun-V1.1-14B-Control-Camera/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.1-Fun-V1.1-14B-Control-Camera \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.1-Fun-V1.1-14B-Control-Camera/metadata.csv \\\n  --height 480 \\\n  --width 832 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"PAI/Wan2.1-Fun-V1.1-14B-Control-Camera:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-V1.1-14B-Control-Camera:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-V1.1-14B-Control-Camera:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-V1.1-14B-Control-Camera:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 2 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Wan2.1-Fun-V1.1-14B-Control-Camera_full\" \\\n  --trainable_models \"dit\" \\\n  --extra_inputs \"input_image,camera_control_direction,camera_control_speed\"\n"
  },
  {
    "path": "examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"wanvideo/Wan2.1-Fun-V1.1-14B-Control/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.1-Fun-V1.1-14B-Control \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.1-Fun-V1.1-14B-Control/metadata.csv \\\n  --data_file_keys \"video,control_video,reference_image\" \\\n  --height 480 \\\n  --width 832 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"PAI/Wan2.1-Fun-V1.1-14B-Control:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-V1.1-14B-Control:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-V1.1-14B-Control:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-V1.1-14B-Control:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 2 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Wan2.1-Fun-V1.1-14B-Control_full\" \\\n  --trainable_models \"dit\" \\\n  --extra_inputs \"control_video,reference_image\"\n"
  },
  {
    "path": "examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"wanvideo/Wan2.1-Fun-V1.1-14B-InP/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.1-Fun-V1.1-14B-InP \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.1-Fun-V1.1-14B-InP/metadata.csv \\\n  --height 480 \\\n  --width 832 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"PAI/Wan2.1-Fun-V1.1-14B-InP:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-V1.1-14B-InP:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-V1.1-14B-InP:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-V1.1-14B-InP:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 2 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Wan2.1-Fun-V1.1-14B-InP_full\" \\\n  --trainable_models \"dit\" \\\n  --extra_inputs \"input_image,end_image\"\n"
  },
  {
    "path": "examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"wanvideo/Wan2.1-I2V-14B-480P/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.1-I2V-14B-480P \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.1-I2V-14B-480P/metadata.csv \\\n  --height 480 \\\n  --width 832 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"Wan-AI/Wan2.1-I2V-14B-480P:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-I2V-14B-480P:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-I2V-14B-480P:Wan2.1_VAE.pth,Wan-AI/Wan2.1-I2V-14B-480P:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 2 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Wan2.1-I2V-14B-480P_full\" \\\n  --trainable_models \"dit\" \\\n  --extra_inputs \"input_image\" \\\n  --initialize_model_on_cpu\n"
  },
  {
    "path": "examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"wanvideo/Wan2.1-I2V-14B-720P/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.1-I2V-14B-720P \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.1-I2V-14B-720P/metadata.csv \\\n  --height 720 \\\n  --width 1280 \\\n  --num_frames 49 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"Wan-AI/Wan2.1-I2V-14B-720P:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-I2V-14B-720P:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-I2V-14B-720P:Wan2.1_VAE.pth,Wan-AI/Wan2.1-I2V-14B-720P:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 2 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Wan2.1-I2V-14B-720P_full\" \\\n  --trainable_models \"dit\" \\\n  --extra_inputs \"input_image\" \\\n  --use_gradient_checkpointing_offload \\\n  --initialize_model_on_cpu\n"
  },
  {
    "path": "examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"wanvideo/Wan2.1-T2V-1.3B/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.1-T2V-1.3B \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.1-T2V-1.3B/metadata.csv \\\n  --height 480 \\\n  --width 832 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-T2V-1.3B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-T2V-1.3B:Wan2.1_VAE.pth\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 2 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Wan2.1-T2V-1.3B_full\" \\\n  --trainable_models \"dit\"\n"
  },
  {
    "path": "examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"wanvideo/Wan2.1-T2V-14B/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.1-T2V-14B \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.1-T2V-14B/metadata.csv \\\n  --height 480 \\\n  --width 832 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"Wan-AI/Wan2.1-T2V-14B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-T2V-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-T2V-14B:Wan2.1_VAE.pth\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 2 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Wan2.1-T2V-14B_full\" \\\n  --trainable_models \"dit\"\n"
  },
  {
    "path": "examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B-Preview.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"wanvideo/Wan2.1-VACE-1.3B-Preview/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.1-VACE-1.3B-Preview \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.1-VACE-1.3B-Preview/metadata.csv \\\n  --data_file_keys \"video,vace_video,vace_reference_image\" \\\n  --height 480 \\\n  --width 832 \\\n  --num_frames 49 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"iic/VACE-Wan2.1-1.3B-Preview:diffusion_pytorch_model*.safetensors,iic/VACE-Wan2.1-1.3B-Preview:models_t5_umt5-xxl-enc-bf16.pth,iic/VACE-Wan2.1-1.3B-Preview:Wan2.1_VAE.pth\" \\\n  --learning_rate 5e-5 \\\n  --num_epochs 2 \\\n  --remove_prefix_in_ckpt \"pipe.vace.\" \\\n  --output_path \"./models/train/Wan2.1-VACE-1.3B-Preview_full\" \\\n  --trainable_models \"vace\" \\\n  --extra_inputs \"vace_video,vace_reference_image\" \\\n  --use_gradient_checkpointing_offload\n# The learning rate is kept consistent with the settings in the original paper\n"
  },
  {
    "path": "examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"wanvideo/Wan2.1-VACE-1.3B/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.1-VACE-1.3B \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.1-VACE-1.3B/metadata.csv \\\n  --data_file_keys \"video,vace_video,vace_reference_image\" \\\n  --height 480 \\\n  --width 832 \\\n  --num_frames 49 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"Wan-AI/Wan2.1-VACE-1.3B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-VACE-1.3B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-VACE-1.3B:Wan2.1_VAE.pth\" \\\n  --learning_rate 5e-5 \\\n  --num_epochs 2 \\\n  --remove_prefix_in_ckpt \"pipe.vace.\" \\\n  --output_path \"./models/train/Wan2.1-VACE-1.3B_full\" \\\n  --trainable_models \"vace\" \\\n  --extra_inputs \"vace_video,vace_reference_image\" \\\n  --use_gradient_checkpointing_offload\n# The learning rate is kept consistent with the settings in the original paper\n"
  },
  {
    "path": "examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"wanvideo/Wan2.1-VACE-14B/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.1-VACE-14B \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.1-VACE-14B/metadata.csv \\\n  --data_file_keys \"video,vace_video,vace_reference_image\" \\\n  --height 480 \\\n  --width 832 \\\n  --num_frames 17 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"Wan-AI/Wan2.1-VACE-14B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-VACE-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-VACE-14B:Wan2.1_VAE.pth\" \\\n  --learning_rate 5e-5 \\\n  --num_epochs 2 \\\n  --remove_prefix_in_ckpt \"pipe.vace.\" \\\n  --output_path \"./models/train/Wan2.1-VACE-14B_full\" \\\n  --trainable_models \"vace\" \\\n  --extra_inputs \"vace_video,vace_reference_image\" \\\n  --use_gradient_checkpointing_offload\n# The learning rate is kept consistent with the settings in the original paper\n"
  },
  {
    "path": "examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"wanvideo/Wan2.2-Animate-14B/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.2-Animate-14B \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.2-Animate-14B/metadata.csv \\\n  --data_file_keys \"video,animate_pose_video,animate_face_video\" \\\n  --height 480 \\\n  --width 832 \\\n  --num_frames 81 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"Wan-AI/Wan2.2-Animate-14B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-Animate-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-Animate-14B:Wan2.1_VAE.pth,Wan-AI/Wan2.2-Animate-14B:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 2 \\\n  --remove_prefix_in_ckpt \"pipe.animate_adapter.\" \\\n  --output_path \"./models/train/Wan2.2-Animate-14B_full\" \\\n  --trainable_models \"animate_adapter\" \\\n  --extra_inputs \"input_image,animate_pose_video,animate_face_video\" \\\n  --use_gradient_checkpointing_offload\n"
  },
  {
    "path": "examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"wanvideo/Wan2.2-Fun-A14B-Control-Camera/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.2-Fun-A14B-Control-Camera \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.2-Fun-A14B-Control-Camera/metadata.csv \\\n  --data_file_keys \"video,control_video,reference_image\" \\\n  --height 480 \\\n  --width 832 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"PAI/Wan2.2-Fun-A14B-Control-Camera:high_noise_model/diffusion_pytorch_model*.safetensors,PAI/Wan2.2-Fun-A14B-Control-Camera:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.2-Fun-A14B-Control-Camera:Wan2.1_VAE.pth\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Wan2.2-Fun-A14B-Control-Camera_high_noise_full\" \\\n  --trainable_models \"dit\" \\\n  --extra_inputs \"input_image,camera_control_direction,camera_control_speed\" \\\n  --max_timestep_boundary 0.358 \\\n  --min_timestep_boundary 0\n# boundary corresponds to timesteps [900, 1000]\n\naccelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.2-Fun-A14B-Control-Camera \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.2-Fun-A14B-Control-Camera/metadata.csv \\\n  --data_file_keys \"video,control_video,reference_image\" \\\n  --height 480 \\\n  --width 832 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths 
\"PAI/Wan2.2-Fun-A14B-Control-Camera:low_noise_model/diffusion_pytorch_model*.safetensors,PAI/Wan2.2-Fun-A14B-Control-Camera:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.2-Fun-A14B-Control-Camera:Wan2.1_VAE.pth\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Wan2.2-Fun-A14B-Control-Camera_low_noise_full\" \\\n  --trainable_models \"dit\" \\\n  --extra_inputs \"input_image,camera_control_direction,camera_control_speed\" \\\n  --max_timestep_boundary 1 \\\n  --min_timestep_boundary 0.358\n# boundary corresponds to timesteps [0, 900]\n"
  },
  {
    "path": "examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"wanvideo/Wan2.2-Fun-A14B-Control/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.2-Fun-A14B-Control \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.2-Fun-A14B-Control/metadata.csv \\\n  --data_file_keys \"video,control_video,reference_image\" \\\n  --height 480 \\\n  --width 832 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"PAI/Wan2.2-Fun-A14B-Control:high_noise_model/diffusion_pytorch_model*.safetensors,PAI/Wan2.2-Fun-A14B-Control:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.2-Fun-A14B-Control:Wan2.1_VAE.pth\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Wan2.2-Fun-A14B-Control_high_noise_full\" \\\n  --trainable_models \"dit\" \\\n  --extra_inputs \"control_video,reference_image\" \\\n  --max_timestep_boundary 0.358 \\\n  --min_timestep_boundary 0\n# boundary corresponds to timesteps [900, 1000]\n\naccelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.2-Fun-A14B-Control \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.2-Fun-A14B-Control/metadata.csv \\\n  --data_file_keys \"video,control_video,reference_image\" \\\n  --height 480 \\\n  --width 832 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"PAI/Wan2.2-Fun-A14B-Control:low_noise_model/diffusion_pytorch_model*.safetensors,PAI/Wan2.2-Fun-A14B-Control:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.2-Fun-A14B-Control:Wan2.1_VAE.pth\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  
--remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Wan2.2-Fun-A14B-Control_low_noise_full\" \\\n  --trainable_models \"dit\" \\\n  --extra_inputs \"control_video,reference_image\" \\\n  --max_timestep_boundary 1 \\\n  --min_timestep_boundary 0.358\n# boundary corresponds to timesteps [0, 900]\n"
  },
  {
    "path": "examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"wanvideo/Wan2.2-Fun-A14B-InP/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.2-Fun-A14B-InP \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.2-Fun-A14B-InP/metadata.csv \\\n  --height 480 \\\n  --width 832 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"PAI/Wan2.2-Fun-A14B-InP:high_noise_model/diffusion_pytorch_model*.safetensors,PAI/Wan2.2-Fun-A14B-InP:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.2-Fun-A14B-InP:Wan2.1_VAE.pth\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Wan2.2-Fun-A14B-InP_high_noise_full\" \\\n  --trainable_models \"dit\" \\\n  --extra_inputs \"input_image,end_image\" \\\n  --max_timestep_boundary 0.358 \\\n  --min_timestep_boundary 0\n# boundary corresponds to timesteps [900, 1000]\n\naccelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.2-Fun-A14B-InP \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.2-Fun-A14B-InP/metadata.csv \\\n  --height 480 \\\n  --width 832 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"PAI/Wan2.2-Fun-A14B-InP:low_noise_model/diffusion_pytorch_model*.safetensors,PAI/Wan2.2-Fun-A14B-InP:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.2-Fun-A14B-InP:Wan2.1_VAE.pth\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Wan2.2-Fun-A14B-InP_low_noise_full\" \\\n  --trainable_models \"dit\" \\\n  --extra_inputs 
\"input_image,end_image\" \\\n  --max_timestep_boundary 1 \\\n  --min_timestep_boundary 0.358\n# boundary corresponds to timesteps [0, 900]\n"
  },
  {
    "path": "examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"wanvideo/Wan2.2-I2V-A14B/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.2-I2V-A14B \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.2-I2V-A14B/metadata.csv \\\n  --height 480 \\\n  --width 832 \\\n  --num_frames 49 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"Wan-AI/Wan2.2-I2V-A14B:high_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-I2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-I2V-A14B:Wan2.1_VAE.pth\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 2 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Wan2.2-I2V-A14B_high_noise_full\" \\\n  --trainable_models \"dit\" \\\n  --extra_inputs \"input_image\" \\\n  --use_gradient_checkpointing_offload \\\n  --max_timestep_boundary 0.358 \\\n  --min_timestep_boundary 0\n# boundary corresponds to timesteps [900, 1000]\n\naccelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.2-I2V-A14B \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.2-I2V-A14B/metadata.csv \\\n  --height 480 \\\n  --width 832 \\\n  --num_frames 49 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"Wan-AI/Wan2.2-I2V-A14B:low_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-I2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-I2V-A14B:Wan2.1_VAE.pth\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 2 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Wan2.2-I2V-A14B_low_noise_full\" \\\n  --trainable_models \"dit\" 
\\\n  --extra_inputs \"input_image\" \\\n  --use_gradient_checkpointing_offload \\\n  --max_timestep_boundary 1 \\\n  --min_timestep_boundary 0.358\n# boundary corresponds to timesteps [0, 900)\n"
  },
  {
    "path": "examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"wanvideo/Wan2.2-S2V-14B/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.2-S2V-14B \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.2-S2V-14B/metadata.csv \\\n  --data_file_keys \"video,input_audio,s2v_pose_video\" \\\n  --height 448 \\\n  --width 832 \\\n  --num_frames 81 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"Wan-AI/Wan2.2-S2V-14B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-S2V-14B:wav2vec2-large-xlsr-53-english/model.safetensors,Wan-AI/Wan2.2-S2V-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-S2V-14B:Wan2.1_VAE.pth\" \\\n  --audio_processor_path \"Wan-AI/Wan2.2-S2V-14B:wav2vec2-large-xlsr-53-english/\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 1 \\\n  --trainable_models \"dit\" \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Wan2.2-S2V-14B_full\" \\\n  --extra_inputs \"input_image,input_audio,s2v_pose_video\" \\\n  --use_gradient_checkpointing_offload\n"
  },
  {
    "path": "examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"wanvideo/Wan2.2-T2V-A14B/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.2-T2V-A14B \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.2-T2V-A14B/metadata.csv \\\n  --height 480 \\\n  --width 832 \\\n  --num_frames 49 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"Wan-AI/Wan2.2-T2V-A14B:high_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-T2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-T2V-A14B:Wan2.1_VAE.pth\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 2 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Wan2.2-T2V-A14B_high_noise_full\" \\\n  --trainable_models \"dit\" \\\n  --max_timestep_boundary 0.417 \\\n  --min_timestep_boundary 0\n# boundary corresponds to timesteps [875, 1000]\n\naccelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.2-T2V-A14B \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.2-T2V-A14B/metadata.csv \\\n  --height 480 \\\n  --width 832 \\\n  --num_frames 49 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"Wan-AI/Wan2.2-T2V-A14B:low_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-T2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-T2V-A14B:Wan2.1_VAE.pth\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 2 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Wan2.2-T2V-A14B_low_noise_full\" \\\n  --trainable_models \"dit\" \\\n  --max_timestep_boundary 1 \\\n  --min_timestep_boundary 0.417\n# boundary 
corresponds to timesteps [0, 875)\n"
  },
  {
    "path": "examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"wanvideo/Wan2.2-TI2V-5B/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.2-TI2V-5B \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.2-TI2V-5B/metadata.csv \\\n  --height 480 \\\n  --width 832 \\\n  --num_frames 49 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"Wan-AI/Wan2.2-TI2V-5B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-TI2V-5B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-TI2V-5B:Wan2.2_VAE.pth\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 2 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Wan2.2-TI2V-5B_full\" \\\n  --trainable_models \"dit\" \\\n  --extra_inputs \"input_image\"\n"
  },
  {
    "path": "examples/wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"wanvideo/Wan2.2-VACE-Fun-A14B/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.2-VACE-Fun-A14B \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.2-VACE-Fun-A14B/metadata.csv \\\n  --data_file_keys \"video,vace_video,vace_reference_image\" \\\n  --height 480 \\\n  --width 832 \\\n  --num_frames 17 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"PAI/Wan2.2-VACE-Fun-A14B:high_noise_model/diffusion_pytorch_model*.safetensors,PAI/Wan2.2-VACE-Fun-A14B:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.2-VACE-Fun-A14B:Wan2.1_VAE.pth\" \\\n  --learning_rate 5e-5 \\\n  --num_epochs 2 \\\n  --remove_prefix_in_ckpt \"pipe.vace.\" \\\n  --output_path \"./models/train/Wan2.2-VACE-Fun-A14B_high_noise_full\" \\\n  --trainable_models \"vace\" \\\n  --extra_inputs \"vace_video,vace_reference_image\" \\\n  --use_gradient_checkpointing_offload \\\n  --max_timestep_boundary 0.358 \\\n  --min_timestep_boundary 0 \\\n  --initialize_model_on_cpu\n# boundary corresponds to timesteps [900, 1000]\n# The learning rate is kept consistent with the settings in the original paper\n\n\naccelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.2-VACE-Fun-A14B \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.2-VACE-Fun-A14B/metadata.csv \\\n  --data_file_keys \"video,vace_video,vace_reference_image\" \\\n  --height 480 \\\n  --width 832 \\\n  --num_frames 17 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths 
\"PAI/Wan2.2-VACE-Fun-A14B:low_noise_model/diffusion_pytorch_model*.safetensors,PAI/Wan2.2-VACE-Fun-A14B:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.2-VACE-Fun-A14B:Wan2.1_VAE.pth\" \\\n  --learning_rate 5e-5 \\\n  --num_epochs 2 \\\n  --remove_prefix_in_ckpt \"pipe.vace.\" \\\n  --output_path \"./models/train/Wan2.2-VACE-Fun-A14B_low_noise_full\" \\\n  --trainable_models \"vace\" \\\n  --extra_inputs \"vace_video,vace_reference_image\" \\\n  --use_gradient_checkpointing_offload \\\n  --max_timestep_boundary 1 \\\n  --min_timestep_boundary 0.358 \\\n  --initialize_model_on_cpu\n# boundary corresponds to timesteps [0, 900]\n# The learning rate is kept consistent with the settings in the original paper\n"
  },
  {
    "path": "examples/wanvideo/model_training/full/WanToDance-14B-global.sh",
    "content": "# 8*H200 required\nmodelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"wanvideo/WanToDance-14B-global/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/WanToDance-14B-global \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/WanToDance-14B-global/metadata.json \\\n  --data_file_keys \"video,wantodance_reference_image,wantodance_keyframes,wantodance_music_path\" \\\n  --height 1280 \\\n  --width 720 \\\n  --num_frames 149 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"Wan-AI/WanToDance-14B:global_model.safetensors,Wan-AI/WanToDance-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/WanToDance-14B:Wan2.1_VAE.pth,Wan-AI/WanToDance-14B:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 2 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/WanToDance-14B-global_full\" \\\n  --trainable_models \"dit\" \\\n  --extra_inputs \"wantodance_music_path,wantodance_reference_image,wantodance_fps,wantodance_keyframes,wantodance_keyframes_mask,framewise_decoding\" \\\n  --use_gradient_checkpointing_offload \\\n  --framewise_decoding\n"
  },
  {
    "path": "examples/wanvideo/model_training/full/WanToDance-14B-local.sh",
    "content": "# 8*H200 required\nmodelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"wanvideo/WanToDance-14B-local/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/WanToDance-14B-local \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/WanToDance-14B-local/metadata.json \\\n  --data_file_keys \"video,wantodance_reference_image,wantodance_keyframes,wantodance_music_path\" \\\n  --height 1280 \\\n  --width 720 \\\n  --num_frames 149 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"Wan-AI/WanToDance-14B:local_model.safetensors,Wan-AI/WanToDance-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/WanToDance-14B:Wan2.1_VAE.pth,Wan-AI/WanToDance-14B:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 2 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/WanToDance-14B-local_full\" \\\n  --trainable_models \"dit\" \\\n  --extra_inputs \"wantodance_music_path,wantodance_reference_image,wantodance_fps,wantodance_keyframes,wantodance_keyframes_mask\" \\\n  --use_gradient_checkpointing_offload\n"
  },
  {
    "path": "examples/wanvideo/model_training/full/accelerate_config_14B.yaml",
    "content": "compute_environment: LOCAL_MACHINE\ndebug: false\ndeepspeed_config:\n  gradient_accumulation_steps: 1\n  offload_optimizer_device: cpu\n  offload_param_device: cpu\n  zero3_init_flag: false\n  zero_stage: 2\ndistributed_type: DEEPSPEED\ndowncast_bf16: 'no'\nenable_cpu_affinity: false\nmachine_rank: 0\nmain_training_function: main\nmixed_precision: bf16\nnum_machines: 1\nnum_processes: 8\nrdzv_backend: static\nsame_network: true\ntpu_env: []\ntpu_use_cluster: false\ntpu_use_sudo: false\nuse_cpu: false\n"
  },
  {
    "path": "examples/wanvideo/model_training/full/accelerate_config_zero3.yaml",
    "content": "compute_environment: LOCAL_MACHINE\ndebug: false\ndeepspeed_config:\n  gradient_accumulation_steps: 1\n  offload_optimizer_device: none\n  offload_param_device: none\n  zero3_init_flag: true\n  zero3_save_16bit_model: true\n  zero_stage: 3\ndistributed_type: DEEPSPEED\ndowncast_bf16: 'no'\nenable_cpu_affinity: false\nmachine_rank: 0\nmain_training_function: main\nmixed_precision: bf16\nnum_machines: 1\nnum_processes: 8\nrdzv_backend: static\nsame_network: true\ntpu_env: []\ntpu_use_cluster: false\ntpu_use_sudo: false\nuse_cpu: false\n"
  },
  {
    "path": "examples/wanvideo/model_training/full/krea-realtime-video.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"wanvideo/krea-realtime-video/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/krea-realtime-video \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/krea-realtime-video/metadata.csv \\\n  --height 480 \\\n  --width 832 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"krea/krea-realtime-video:krea-realtime-video-14b.safetensors,Wan-AI/Wan2.1-T2V-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-T2V-14B:Wan2.1_VAE.pth\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 2 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/krea-realtime-video_full\" \\\n  --trainable_models \"dit\"\n"
  },
  {
    "path": "examples/wanvideo/model_training/lora/LongCat-Video.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"wanvideo/LongCat-Video/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/LongCat-Video \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/LongCat-Video/metadata.csv \\\n  --height 480 \\\n  --width 832 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"meituan-longcat/LongCat-Video:dit/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-T2V-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-T2V-14B:Wan2.1_VAE.pth\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/LongCat-Video_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"adaLN_modulation.1,attn.qkv,attn.proj,cross_attn.q_linear,cross_attn.kv_linear,cross_attn.proj,ffn.w1,ffn.w2,ffn.w3\" \\\n  --lora_rank 32\n"
  },
  {
    "path": "examples/wanvideo/model_training/lora/Video-As-Prompt-Wan2.1-14B.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"wanvideo/Video-As-Prompt-Wan2.1-14B/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Video-As-Prompt-Wan2.1-14B \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Video-As-Prompt-Wan2.1-14B/metadata.csv \\\n  --data_file_keys \"video,vap_video\" \\\n  --height 480 \\\n  --width 832 \\\n  --num_frames 49 \\\n  --dataset_repeat 10 \\\n  --model_id_with_origin_paths \"ByteDance/Video-As-Prompt-Wan2.1-14B:transformer/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-I2V-14B-720P:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-I2V-14B-720P:Wan2.1_VAE.pth,Wan-AI/Wan2.1-I2V-14B-720P:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Video-As-Prompt-Wan2.1-14B_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"q,k,v,o,ffn.0,ffn.2\" \\\n  --lora_rank 32 \\\n  --extra_inputs \"vap_video,input_image\" \\\n  --use_gradient_checkpointing_offload\n"
  },
  {
    "path": "examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"wanvideo/Wan2.1-1.3b-speedcontrol-v1/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.1-1.3b-speedcontrol-v1 \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.1-1.3b-speedcontrol-v1/metadata.csv \\\n  --height 480 \\\n  --width 832 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-T2V-1.3B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-T2V-1.3B:Wan2.1_VAE.pth,DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1:model.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Wan2.1-1.3b-speedcontrol-v1_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"q,k,v,o,ffn.0,ffn.2\" \\\n  --lora_rank 32 \\\n  --extra_inputs \"motion_bucket_id\"\n"
  },
  {
    "path": "examples/wanvideo/model_training/lora/Wan2.1-FLF2V-14B-720P.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"wanvideo/Wan2.1-FLF2V-14B-720P/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.1-FLF2V-14B-720P \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.1-FLF2V-14B-720P/metadata.csv \\\n  --height 480 \\\n  --width 832 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"Wan-AI/Wan2.1-FLF2V-14B-720P:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-FLF2V-14B-720P:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-FLF2V-14B-720P:Wan2.1_VAE.pth,Wan-AI/Wan2.1-FLF2V-14B-720P:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Wan2.1-FLF2V-14B-720P_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"q,k,v,o,ffn.0,ffn.2\" \\\n  --lora_rank 32 \\\n  --extra_inputs \"input_image,end_image\"\n"
  },
  {
    "path": "examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-Control.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"wanvideo/Wan2.1-Fun-1.3B-Control/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.1-Fun-1.3B-Control \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.1-Fun-1.3B-Control/metadata.csv \\\n  --data_file_keys \"video,control_video\" \\\n  --height 480 \\\n  --width 832 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"PAI/Wan2.1-Fun-1.3B-Control:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-1.3B-Control:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-1.3B-Control:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-1.3B-Control:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Wan2.1-Fun-1.3B-Control_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"q,k,v,o,ffn.0,ffn.2\" \\\n  --lora_rank 32 \\\n  --extra_inputs \"control_video\"\n"
  },
  {
    "path": "examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-InP.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"wanvideo/Wan2.1-Fun-1.3B-InP/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.1-Fun-1.3B-InP \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.1-Fun-1.3B-InP/metadata.csv \\\n  --height 480 \\\n  --width 832 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"PAI/Wan2.1-Fun-1.3B-InP:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-1.3B-InP:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-1.3B-InP:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-1.3B-InP:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Wan2.1-Fun-1.3B-InP_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"q,k,v,o,ffn.0,ffn.2\" \\\n  --lora_rank 32 \\\n  --extra_inputs \"input_image,end_image\"\n"
  },
  {
    "path": "examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-Control.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"wanvideo/Wan2.1-Fun-14B-Control/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.1-Fun-14B-Control \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.1-Fun-14B-Control/metadata.csv \\\n  --data_file_keys \"video,control_video\" \\\n  --height 480 \\\n  --width 832 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"PAI/Wan2.1-Fun-14B-Control:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-14B-Control:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-14B-Control:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-14B-Control:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Wan2.1-Fun-14B-Control_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"q,k,v,o,ffn.0,ffn.2\" \\\n  --lora_rank 32 \\\n  --extra_inputs \"control_video\"\n"
  },
  {
    "path": "examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-InP.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"wanvideo/Wan2.1-Fun-14B-InP/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.1-Fun-14B-InP \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.1-Fun-14B-InP/metadata.csv \\\n  --height 480 \\\n  --width 832 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"PAI/Wan2.1-Fun-14B-InP:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-14B-InP:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-14B-InP:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-14B-InP:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Wan2.1-Fun-14B-InP_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"q,k,v,o,ffn.0,ffn.2\" \\\n  --lora_rank 32 \\\n  --extra_inputs \"input_image,end_image\"\n"
  },
  {
    "path": "examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"wanvideo/Wan2.1-Fun-V1.1-1.3B-Control-Camera/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.1-Fun-V1.1-1.3B-Control-Camera \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.1-Fun-V1.1-1.3B-Control-Camera/metadata.csv \\\n  --height 480 \\\n  --width 832 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Wan2.1-Fun-V1.1-1.3B-Control-Camera_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"q,k,v,o,ffn.0,ffn.2\" \\\n  --lora_rank 32 \\\n  --extra_inputs \"input_image,camera_control_direction,camera_control_speed\"\n"
  },
  {
    "path": "examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"wanvideo/Wan2.1-Fun-V1.1-1.3B-Control/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.1-Fun-V1.1-1.3B-Control \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.1-Fun-V1.1-1.3B-Control/metadata.csv \\\n  --data_file_keys \"video,control_video,reference_image\" \\\n  --height 480 \\\n  --width 832 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"PAI/Wan2.1-Fun-V1.1-1.3B-Control:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-V1.1-1.3B-Control:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-V1.1-1.3B-Control:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-V1.1-1.3B-Control:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Wan2.1-Fun-V1.1-1.3B-Control_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"q,k,v,o,ffn.0,ffn.2\" \\\n  --lora_rank 32 \\\n  --extra_inputs \"control_video,reference_image\"\n"
  },
  {
    "path": "examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"wanvideo/Wan2.1-Fun-V1.1-1.3B-InP/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.1-Fun-V1.1-1.3B-InP \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.1-Fun-V1.1-1.3B-InP/metadata.csv \\\n  --height 480 \\\n  --width 832 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"PAI/Wan2.1-Fun-V1.1-1.3B-InP:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-V1.1-1.3B-InP:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-V1.1-1.3B-InP:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-V1.1-1.3B-InP:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Wan2.1-Fun-V1.1-1.3B-InP_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"q,k,v,o,ffn.0,ffn.2\" \\\n  --lora_rank 32 \\\n  --extra_inputs \"input_image,end_image\"\n"
  },
  {
    "path": "examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control-Camera.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"wanvideo/Wan2.1-Fun-V1.1-14B-Control-Camera/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.1-Fun-V1.1-14B-Control-Camera \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.1-Fun-V1.1-14B-Control-Camera/metadata.csv \\\n  --height 480 \\\n  --width 832 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"PAI/Wan2.1-Fun-V1.1-14B-Control-Camera:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-V1.1-14B-Control-Camera:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-V1.1-14B-Control-Camera:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-V1.1-14B-Control-Camera:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Wan2.1-Fun-V1.1-14B-Control-Camera_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"q,k,v,o,ffn.0,ffn.2\" \\\n  --lora_rank 32 \\\n  --extra_inputs \"input_image,camera_control_direction,camera_control_speed\"\n"
  },
  {
    "path": "examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"wanvideo/Wan2.1-Fun-V1.1-14B-Control/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.1-Fun-V1.1-14B-Control \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.1-Fun-V1.1-14B-Control/metadata.csv \\\n  --data_file_keys \"video,control_video,reference_image\" \\\n  --height 480 \\\n  --width 832 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"PAI/Wan2.1-Fun-V1.1-14B-Control:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-V1.1-14B-Control:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-V1.1-14B-Control:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-V1.1-14B-Control:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Wan2.1-Fun-V1.1-14B-Control_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"q,k,v,o,ffn.0,ffn.2\" \\\n  --lora_rank 32 \\\n  --extra_inputs \"control_video,reference_image\"\n"
  },
  {
    "path": "examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"wanvideo/Wan2.1-Fun-V1.1-14B-InP/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.1-Fun-V1.1-14B-InP \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.1-Fun-V1.1-14B-InP/metadata.csv \\\n  --height 480 \\\n  --width 832 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"PAI/Wan2.1-Fun-V1.1-14B-InP:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-V1.1-14B-InP:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-V1.1-14B-InP:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-V1.1-14B-InP:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Wan2.1-Fun-V1.1-14B-InP_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"q,k,v,o,ffn.0,ffn.2\" \\\n  --lora_rank 32 \\\n  --extra_inputs \"input_image,end_image\"\n"
  },
  {
    "path": "examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"wanvideo/Wan2.1-I2V-14B-480P/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.1-I2V-14B-480P \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.1-I2V-14B-480P/metadata.csv \\\n  --height 480 \\\n  --width 832 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"Wan-AI/Wan2.1-I2V-14B-480P:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-I2V-14B-480P:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-I2V-14B-480P:Wan2.1_VAE.pth,Wan-AI/Wan2.1-I2V-14B-480P:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Wan2.1-I2V-14B-480P_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"q,k,v,o,ffn.0,ffn.2\" \\\n  --lora_rank 32 \\\n  --extra_inputs \"input_image\"\n"
  },
  {
    "path": "examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"wanvideo/Wan2.1-I2V-14B-720P/*\" --local_dir ./data/diffsynth_example_dataset\n\n# 1*80G GPU cannot train Wan2.1-I2V-14B-720P LoRA\n# We tested on 8*80G GPUs\naccelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.1-I2V-14B-720P \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.1-I2V-14B-720P/metadata.csv \\\n  --height 720 \\\n  --width 1280 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"Wan-AI/Wan2.1-I2V-14B-720P:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-I2V-14B-720P:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-I2V-14B-720P:Wan2.1_VAE.pth,Wan-AI/Wan2.1-I2V-14B-720P:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Wan2.1-I2V-14B-720P_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"q,k,v,o,ffn.0,ffn.2\" \\\n  --lora_rank 32 \\\n  --extra_inputs \"input_image\" \\\n  --use_gradient_checkpointing_offload \\\n  --initialize_model_on_cpu\n"
  },
  {
    "path": "examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"wanvideo/Wan2.1-T2V-1.3B/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.1-T2V-1.3B \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.1-T2V-1.3B/metadata.csv \\\n  --height 480 \\\n  --width 832 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-T2V-1.3B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-T2V-1.3B:Wan2.1_VAE.pth\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Wan2.1-T2V-1.3B_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"q,k,v,o,ffn.0,ffn.2\" \\\n  --lora_rank 32\n"
  },
  {
    "path": "examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"wanvideo/Wan2.1-T2V-14B/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.1-T2V-14B \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.1-T2V-14B/metadata.csv \\\n  --height 480 \\\n  --width 832 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"Wan-AI/Wan2.1-T2V-14B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-T2V-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-T2V-14B:Wan2.1_VAE.pth\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Wan2.1-T2V-14B_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"q,k,v,o,ffn.0,ffn.2\" \\\n  --lora_rank 32\n"
  },
  {
    "path": "examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"wanvideo/Wan2.1-VACE-1.3B-Preview/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.1-VACE-1.3B-Preview \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.1-VACE-1.3B-Preview/metadata.csv \\\n  --data_file_keys \"video,vace_video,vace_reference_image\" \\\n  --height 480 \\\n  --width 832 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"iic/VACE-Wan2.1-1.3B-Preview:diffusion_pytorch_model*.safetensors,iic/VACE-Wan2.1-1.3B-Preview:models_t5_umt5-xxl-enc-bf16.pth,iic/VACE-Wan2.1-1.3B-Preview:Wan2.1_VAE.pth\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.vace.\" \\\n  --output_path \"./models/train/Wan2.1-VACE-1.3B-Preview_lora\" \\\n  --lora_base_model \"vace\" \\\n  --lora_target_modules \"q,k,v,o,ffn.0,ffn.2\" \\\n  --lora_rank 32 \\\n  --extra_inputs \"vace_video,vace_reference_image\" \\\n  --use_gradient_checkpointing_offload\n"
  },
  {
    "path": "examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"wanvideo/Wan2.1-VACE-1.3B/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.1-VACE-1.3B \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.1-VACE-1.3B/metadata.csv \\\n  --data_file_keys \"video,vace_video,vace_reference_image\" \\\n  --height 480 \\\n  --width 832 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"Wan-AI/Wan2.1-VACE-1.3B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-VACE-1.3B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-VACE-1.3B:Wan2.1_VAE.pth\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.vace.\" \\\n  --output_path \"./models/train/Wan2.1-VACE-1.3B_lora\" \\\n  --lora_base_model \"vace\" \\\n  --lora_target_modules \"q,k,v,o,ffn.0,ffn.2\" \\\n  --lora_rank 32 \\\n  --extra_inputs \"vace_video,vace_reference_image\" \\\n  --use_gradient_checkpointing_offload\n"
  },
  {
    "path": "examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"wanvideo/Wan2.1-VACE-14B/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.1-VACE-14B \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.1-VACE-14B/metadata.csv \\\n  --data_file_keys \"video,vace_video,vace_reference_image\" \\\n  --height 480 \\\n  --width 832 \\\n  --num_frames 17 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"Wan-AI/Wan2.1-VACE-14B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-VACE-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-VACE-14B:Wan2.1_VAE.pth\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.vace.\" \\\n  --output_path \"./models/train/Wan2.1-VACE-14B_lora\" \\\n  --lora_base_model \"vace\" \\\n  --lora_target_modules \"q,k,v,o,ffn.0,ffn.2\" \\\n  --lora_rank 32 \\\n  --extra_inputs \"vace_video,vace_reference_image\" \\\n  --use_gradient_checkpointing_offload\n"
  },
  {
    "path": "examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"wanvideo/Wan2.2-Animate-14B/*\" --local_dir ./data/diffsynth_example_dataset\n\n# 1*80G GPU cannot train Wan2.2-Animate-14B LoRA\n# We tested on 8*80G GPUs\naccelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.2-Animate-14B \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.2-Animate-14B/metadata.csv \\\n  --data_file_keys \"video,animate_pose_video,animate_face_video\" \\\n  --height 480 \\\n  --width 832 \\\n  --num_frames 81 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"Wan-AI/Wan2.2-Animate-14B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-Animate-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-Animate-14B:Wan2.1_VAE.pth,Wan-AI/Wan2.2-Animate-14B:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Wan2.2-Animate-14B_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"q,k,v,o,ffn.0,ffn.2\" \\\n  --lora_rank 32 \\\n  --extra_inputs \"input_image,animate_pose_video,animate_face_video\" \\\n  --use_gradient_checkpointing_offload\n"
  },
  {
    "path": "examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"wanvideo/Wan2.2-Fun-A14B-Control-Camera/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.2-Fun-A14B-Control-Camera \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.2-Fun-A14B-Control-Camera/metadata.csv \\\n  --data_file_keys \"video,control_video,reference_image\" \\\n  --height 480 \\\n  --width 832 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"PAI/Wan2.2-Fun-A14B-Control-Camera:high_noise_model/diffusion_pytorch_model*.safetensors,PAI/Wan2.2-Fun-A14B-Control-Camera:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.2-Fun-A14B-Control-Camera:Wan2.1_VAE.pth\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Wan2.2-Fun-A14B-Control-Camera_high_noise_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"q,k,v,o,ffn.0,ffn.2\" \\\n  --lora_rank 32 \\\n  --extra_inputs \"input_image,camera_control_direction,camera_control_speed\" \\\n  --max_timestep_boundary 0.358 \\\n  --min_timestep_boundary 0\n# boundary corresponds to timesteps [900, 1000]\n\naccelerate launch examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.2-Fun-A14B-Control-Camera \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.2-Fun-A14B-Control-Camera/metadata.csv \\\n  --data_file_keys \"video,control_video,reference_image\" \\\n  --height 480 \\\n  --width 832 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"PAI/Wan2.2-Fun-A14B-Control-Camera:low_noise_model/diffusion_pytorch_model*.safetensors,PAI/Wan2.2-Fun-A14B-Control-Camera:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.2-Fun-A14B-Control-Camera:Wan2.1_VAE.pth\" \\\n  --learning_rate 1e-4 \\\n  
--num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Wan2.2-Fun-A14B-Control-Camera_low_noise_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"q,k,v,o,ffn.0,ffn.2\" \\\n  --lora_rank 32 \\\n  --extra_inputs \"input_image,camera_control_direction,camera_control_speed\" \\\n  --max_timestep_boundary 1 \\\n  --min_timestep_boundary 0.358\n# boundary corresponds to timesteps [0, 900]\n"
  },
  {
    "path": "examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"wanvideo/Wan2.2-Fun-A14B-Control/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.2-Fun-A14B-Control \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.2-Fun-A14B-Control/metadata.csv \\\n  --data_file_keys \"video,control_video,reference_image\" \\\n  --height 480 \\\n  --width 832 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"PAI/Wan2.2-Fun-A14B-Control:high_noise_model/diffusion_pytorch_model*.safetensors,PAI/Wan2.2-Fun-A14B-Control:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.2-Fun-A14B-Control:Wan2.1_VAE.pth\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Wan2.2-Fun-A14B-Control_high_noise_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"q,k,v,o,ffn.0,ffn.2\" \\\n  --lora_rank 32 \\\n  --extra_inputs \"control_video,reference_image\" \\\n  --max_timestep_boundary 0.358 \\\n  --min_timestep_boundary 0\n# boundary corresponds to timesteps [900, 1000]\n\naccelerate launch examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.2-Fun-A14B-Control \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.2-Fun-A14B-Control/metadata.csv \\\n  --data_file_keys \"video,control_video,reference_image\" \\\n  --height 480 \\\n  --width 832 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"PAI/Wan2.2-Fun-A14B-Control:low_noise_model/diffusion_pytorch_model*.safetensors,PAI/Wan2.2-Fun-A14B-Control:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.2-Fun-A14B-Control:Wan2.1_VAE.pth\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path 
\"./models/train/Wan2.2-Fun-A14B-Control_low_noise_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"q,k,v,o,ffn.0,ffn.2\" \\\n  --lora_rank 32 \\\n  --extra_inputs \"control_video,reference_image\" \\\n  --max_timestep_boundary 1 \\\n  --min_timestep_boundary 0.358\n# boundary corresponds to timesteps [0, 900]\n"
  },
  {
    "path": "examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"wanvideo/Wan2.2-Fun-A14B-InP/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.2-Fun-A14B-InP \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.2-Fun-A14B-InP/metadata.csv \\\n  --height 480 \\\n  --width 832 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"PAI/Wan2.2-Fun-A14B-InP:high_noise_model/diffusion_pytorch_model*.safetensors,PAI/Wan2.2-Fun-A14B-InP:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.2-Fun-A14B-InP:Wan2.1_VAE.pth\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Wan2.2-Fun-A14B-InP_high_noise_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"q,k,v,o,ffn.0,ffn.2\" \\\n  --lora_rank 32 \\\n  --extra_inputs \"input_image,end_image\" \\\n  --max_timestep_boundary 0.358 \\\n  --min_timestep_boundary 0\n# boundary corresponds to timesteps [900, 1000]\n\naccelerate launch examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.2-Fun-A14B-InP \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.2-Fun-A14B-InP/metadata.csv \\\n  --height 480 \\\n  --width 832 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"PAI/Wan2.2-Fun-A14B-InP:low_noise_model/diffusion_pytorch_model*.safetensors,PAI/Wan2.2-Fun-A14B-InP:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.2-Fun-A14B-InP:Wan2.1_VAE.pth\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Wan2.2-Fun-A14B-InP_low_noise_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"q,k,v,o,ffn.0,ffn.2\" \\\n  --lora_rank 32 \\\n  --extra_inputs \"input_image,end_image\" \\\n  
--max_timestep_boundary 1 \\\n  --min_timestep_boundary 0.358\n# boundary corresponds to timesteps [0, 900]\n"
  },
  {
    "path": "examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"wanvideo/Wan2.2-I2V-A14B/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.2-I2V-A14B \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.2-I2V-A14B/metadata.csv \\\n  --height 480 \\\n  --width 832 \\\n  --num_frames 49 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"Wan-AI/Wan2.2-I2V-A14B:high_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-I2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-I2V-A14B:Wan2.1_VAE.pth\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Wan2.2-I2V-A14B_high_noise_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"q,k,v,o,ffn.0,ffn.2\" \\\n  --lora_rank 32 \\\n  --extra_inputs \"input_image\" \\\n  --max_timestep_boundary 0.358 \\\n  --min_timestep_boundary 0\n# boundary corresponds to timesteps [900, 1000]\n\naccelerate launch examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.2-I2V-A14B \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.2-I2V-A14B/metadata.csv \\\n  --height 480 \\\n  --width 832 \\\n  --num_frames 49 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"Wan-AI/Wan2.2-I2V-A14B:low_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-I2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-I2V-A14B:Wan2.1_VAE.pth\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Wan2.2-I2V-A14B_low_noise_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"q,k,v,o,ffn.0,ffn.2\" \\\n  --lora_rank 32 \\\n  --extra_inputs \"input_image\" \\\n  
--max_timestep_boundary 1 \\\n  --min_timestep_boundary 0.358\n# boundary corresponds to timesteps [0, 900)\n"
  },
  {
    "path": "examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"wanvideo/Wan2.2-S2V-14B/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.2-S2V-14B \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.2-S2V-14B/metadata.csv \\\n  --data_file_keys \"video,input_audio,s2v_pose_video\" \\\n  --height 448 \\\n  --width 832 \\\n  --num_frames 81 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"Wan-AI/Wan2.2-S2V-14B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-S2V-14B:wav2vec2-large-xlsr-53-english/model.safetensors,Wan-AI/Wan2.2-S2V-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-S2V-14B:Wan2.1_VAE.pth\" \\\n  --audio_processor_path \"Wan-AI/Wan2.2-S2V-14B:wav2vec2-large-xlsr-53-english/\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Wan2.2-S2V-14B_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"q,k,v,o,ffn.0,ffn.2\" \\\n  --lora_rank 32 \\\n  --extra_inputs \"input_image,input_audio,s2v_pose_video\" \\\n  --use_gradient_checkpointing_offload\n"
  },
  {
    "path": "examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"wanvideo/Wan2.2-T2V-A14B/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.2-T2V-A14B \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.2-T2V-A14B/metadata.csv \\\n  --height 480 \\\n  --width 832 \\\n  --num_frames 49 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"Wan-AI/Wan2.2-T2V-A14B:high_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-T2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-T2V-A14B:Wan2.1_VAE.pth\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Wan2.2-T2V-A14B_high_noise_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"q,k,v,o,ffn.0,ffn.2\" \\\n  --lora_rank 32 \\\n  --max_timestep_boundary 0.417 \\\n  --min_timestep_boundary 0\n# boundary corresponds to timesteps [875, 1000]\n\n\naccelerate launch examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.2-T2V-A14B \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.2-T2V-A14B/metadata.csv \\\n  --height 480 \\\n  --width 832 \\\n  --num_frames 49 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"Wan-AI/Wan2.2-T2V-A14B:low_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-T2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-T2V-A14B:Wan2.1_VAE.pth\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Wan2.2-T2V-A14B_low_noise_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"q,k,v,o,ffn.0,ffn.2\" \\\n  --lora_rank 32 \\\n  --max_timestep_boundary 1 \\\n  --min_timestep_boundary 0.417\n# boundary corresponds 
to timesteps [0, 875)\n"
  },
  {
    "path": "examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"wanvideo/Wan2.2-TI2V-5B/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.2-TI2V-5B \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.2-TI2V-5B/metadata.csv \\\n  --height 480 \\\n  --width 832 \\\n  --num_frames 49 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"Wan-AI/Wan2.2-TI2V-5B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-TI2V-5B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-TI2V-5B:Wan2.2_VAE.pth\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Wan2.2-TI2V-5B_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"q,k,v,o,ffn.0,ffn.2\" \\\n  --lora_rank 32 \\\n  --extra_inputs \"input_image\"\n"
  },
  {
    "path": "examples/wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"wanvideo/Wan2.2-VACE-Fun-A14B/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.2-VACE-Fun-A14B \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.2-VACE-Fun-A14B/metadata.csv \\\n  --data_file_keys \"video,vace_video,vace_reference_image\" \\\n  --height 480 \\\n  --width 832 \\\n  --num_frames 17 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"PAI/Wan2.2-VACE-Fun-A14B:high_noise_model/diffusion_pytorch_model*.safetensors,PAI/Wan2.2-VACE-Fun-A14B:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.2-VACE-Fun-A14B:Wan2.1_VAE.pth\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.vace.\" \\\n  --output_path \"./models/train/Wan2.2-VACE-Fun-A14B_high_noise_lora\" \\\n  --lora_base_model \"vace\" \\\n  --lora_target_modules \"q,k,v,o,ffn.0,ffn.2\" \\\n  --lora_rank 32 \\\n  --extra_inputs \"vace_video,vace_reference_image\" \\\n  --use_gradient_checkpointing_offload \\\n  --max_timestep_boundary 0.358 \\\n  --min_timestep_boundary 0\n# boundary corresponds to timesteps [900, 1000]\n\naccelerate launch examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.2-VACE-Fun-A14B \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.2-VACE-Fun-A14B/metadata.csv \\\n  --data_file_keys \"video,vace_video,vace_reference_image\" \\\n  --height 480 \\\n  --width 832 \\\n  --num_frames 17 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"PAI/Wan2.2-VACE-Fun-A14B:low_noise_model/diffusion_pytorch_model*.safetensors,PAI/Wan2.2-VACE-Fun-A14B:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.2-VACE-Fun-A14B:Wan2.1_VAE.pth\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt 
\"pipe.vace.\" \\\n  --output_path \"./models/train/Wan2.2-VACE-Fun-A14B_low_noise_lora\" \\\n  --lora_base_model \"vace\" \\\n  --lora_target_modules \"q,k,v,o,ffn.0,ffn.2\" \\\n  --lora_rank 32 \\\n  --extra_inputs \"vace_video,vace_reference_image\" \\\n  --use_gradient_checkpointing_offload \\\n  --max_timestep_boundary 1 \\\n  --min_timestep_boundary 0.358\n# boundary corresponds to timesteps [0, 900]\n"
  },
  {
    "path": "examples/wanvideo/model_training/lora/WanToDance-14B-global.sh",
    "content": "# 8*H200 required\nmodelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"wanvideo/WanToDance-14B-global/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/WanToDance-14B-global \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/WanToDance-14B-global/metadata.json \\\n  --data_file_keys \"video,wantodance_reference_image,wantodance_keyframes,wantodance_music_path\" \\\n  --height 1280 \\\n  --width 720 \\\n  --num_frames 149 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"Wan-AI/WanToDance-14B:global_model.safetensors,Wan-AI/WanToDance-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/WanToDance-14B:Wan2.1_VAE.pth,Wan-AI/WanToDance-14B:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/WanToDance-14B-global_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"q,k,v,o,ffn.0,ffn.2\" \\\n  --lora_rank 32 \\\n  --extra_inputs \"wantodance_music_path,wantodance_reference_image,wantodance_fps,wantodance_keyframes,wantodance_keyframes_mask,framewise_decoding\" \\\n  --use_gradient_checkpointing_offload \\\n  --framewise_decoding\n"
  },
  {
    "path": "examples/wanvideo/model_training/lora/WanToDance-14B-local.sh",
    "content": "# 8*H200 required\nmodelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"wanvideo/WanToDance-14B-local/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/WanToDance-14B-local \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/WanToDance-14B-local/metadata.json \\\n  --data_file_keys \"video,wantodance_reference_image,wantodance_keyframes,wantodance_music_path\" \\\n  --height 1280 \\\n  --width 720 \\\n  --num_frames 149 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"Wan-AI/WanToDance-14B:local_model.safetensors,Wan-AI/WanToDance-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/WanToDance-14B:Wan2.1_VAE.pth,Wan-AI/WanToDance-14B:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/WanToDance-14B-local_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"q,k,v,o,ffn.0,ffn.2\" \\\n  --lora_rank 32 \\\n  --extra_inputs \"wantodance_music_path,wantodance_reference_image,wantodance_fps,wantodance_keyframes,wantodance_keyframes_mask\" \\\n  --use_gradient_checkpointing_offload\n"
  },
  {
    "path": "examples/wanvideo/model_training/lora/krea-realtime-video.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"wanvideo/krea-realtime-video/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/krea-realtime-video \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/krea-realtime-video/metadata.csv \\\n  --height 480 \\\n  --width 832 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"krea/krea-realtime-video:krea-realtime-video-14b.safetensors,Wan-AI/Wan2.1-T2V-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-T2V-14B:Wan2.1_VAE.pth\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/krea-realtime-video_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"q,k,v,o,ffn.0,ffn.2\" \\\n  --lora_rank 32\n"
  },
  {
    "path": "examples/wanvideo/model_training/special/direct_distill/Wan2.1-T2V-1.3B.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"wanvideo/Wan2.1-T2V-1.3B_direct_distill/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.1-T2V-1.3B_direct_distill \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.1-T2V-1.3B_direct_distill/metadata.csv \\\n  --height 480 \\\n  --width 832 \\\n  --dataset_repeat 160 \\\n  --model_id_with_origin_paths \"Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-T2V-1.3B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-T2V-1.3B:Wan2.1_VAE.pth\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 2 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Wan2.1-T2V-1.3B_full_distill\" \\\n  --trainable_models \"dit\" \\\n  --task \"direct_distill\" \\\n  --extra_inputs \"seed,rand_device,num_inference_steps,cfg_scale\"\n"
  },
  {
    "path": "examples/wanvideo/model_training/special/direct_distill/validate.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(\"models/train/Wan2.1-T2V-1.3B_full_distill/epoch-1.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n)\n\nvideo = pipe(\n    prompt=\"纪实摄影风格画面，一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄，两只耳朵立起，神情专注而欢快。阳光洒在它身上，使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地，偶尔点缀着几朵野花，远处隐约可见蓝天和几片白云。透视感鲜明，捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。\",\n    cfg_scale=1, num_inference_steps=4,\n    seed=0, tiled=True,\n)\nsave_video(video, \"video_distill_Wan2.1-T2V-1.3B.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_training/special/fp8_training/Wan2.1-I2V-14B-480P.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"wanvideo/Wan2.1-I2V-14B-480P/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.1-I2V-14B-480P \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.1-I2V-14B-480P/metadata.csv \\\n  --height 480 \\\n  --width 832 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"Wan-AI/Wan2.1-I2V-14B-480P:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-I2V-14B-480P:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-I2V-14B-480P:Wan2.1_VAE.pth,Wan-AI/Wan2.1-I2V-14B-480P:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Wan2.1-I2V-14B-480P_lora_fp8\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"q,k,v,o,ffn.0,ffn.2\" \\\n  --lora_rank 32 \\\n  --extra_inputs \"input_image\" \\\n  --fp8_models \"Wan-AI/Wan2.1-I2V-14B-480P:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-I2V-14B-480P:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-I2V-14B-480P:Wan2.1_VAE.pth,Wan-AI/Wan2.1-I2V-14B-480P:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\"\n"
  },
  {
    "path": "examples/wanvideo/model_training/special/fp8_training/validate.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-I2V-14B-480P\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-I2V-14B-480P\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-I2V-14B-480P\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-I2V-14B-480P\", origin_file_pattern=\"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\"),\n    ],\n)\npipe.load_lora(pipe.dit, \"models/train/Wan2.1-I2V-14B-480P_lora_fp8/epoch-4.safetensors\", alpha=1)\n\ninput_image = VideoData(\"data/example_video_dataset/video1.mp4\", height=480, width=832)[0]\n\nvideo = pipe(\n    prompt=\"from sunset to night, a small town, light, house, river\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    input_image=input_image,\n    seed=1, tiled=True\n)\nsave_video(video, \"video_Wan2.1-I2V-14B-480P.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_training/special/low_vram_training/Wan2.1-I2V-14B-480P.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"wanvideo/Wan2.1-I2V-14B-480P/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.1-I2V-14B-480P \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.1-I2V-14B-480P/metadata.csv \\\n  --height 480 \\\n  --width 832 \\\n  --dataset_repeat 1 \\\n  --model_id_with_origin_paths \"Wan-AI/Wan2.1-I2V-14B-480P:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-I2V-14B-480P:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-I2V-14B-480P:Wan2.1_VAE.pth,Wan-AI/Wan2.1-I2V-14B-480P:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Wan2.1-I2V-14B-480P_lora_lowvram_cache\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"q,k,v,o,ffn.0,ffn.2\" \\\n  --lora_rank 32 \\\n  --extra_inputs \"input_image\" \\\n  --task \"sft:data_process\" \\\n  --offload_models \"Wan-AI/Wan2.1-I2V-14B-480P:diffusion_pytorch_model*.safetensors\" \\\n  --fp8_models \"Wan-AI/Wan2.1-I2V-14B-480P:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-I2V-14B-480P:Wan2.1_VAE.pth,Wan-AI/Wan2.1-I2V-14B-480P:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\" \\\n  --use_gradient_checkpointing_offload\n\naccelerate launch examples/wanvideo/model_training/train.py \\\n  --dataset_base_path \"./models/train/Wan2.1-I2V-14B-480P_lora_lowvram_cache\" \\\n  --height 480 \\\n  --width 832 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"Wan-AI/Wan2.1-I2V-14B-480P:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-I2V-14B-480P:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-I2V-14B-480P:Wan2.1_VAE.pth,Wan-AI/Wan2.1-I2V-14B-480P:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\" \\\n  --learning_rate 1e-4 \\\n  
--num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Wan2.1-I2V-14B-480P_lora_lowvram\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"q,k,v,o,ffn.0,ffn.2\" \\\n  --lora_rank 32 \\\n  --extra_inputs \"input_image\" \\\n  --task \"sft:train\" \\\n  --offload_models \"Wan-AI/Wan2.1-I2V-14B-480P:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-I2V-14B-480P:Wan2.1_VAE.pth,Wan-AI/Wan2.1-I2V-14B-480P:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\" \\\n  --fp8_models \"Wan-AI/Wan2.1-I2V-14B-480P:diffusion_pytorch_model*.safetensors\" \\\n  --use_gradient_checkpointing_offload\n"
  },
  {
    "path": "examples/wanvideo/model_training/special/low_vram_training/validate.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-I2V-14B-480P\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-I2V-14B-480P\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-I2V-14B-480P\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-I2V-14B-480P\", origin_file_pattern=\"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\"),\n    ],\n)\npipe.load_lora(pipe.dit, \"models/train/Wan2.1-I2V-14B-480P_lora_lowvram/epoch-4.safetensors\", alpha=1)\n\ninput_image = VideoData(\"data/example_video_dataset/video1.mp4\", height=480, width=832)[0]\n\nvideo = pipe(\n    prompt=\"from sunset to night, a small town, light, house, river\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    input_image=input_image,\n    seed=1, tiled=True\n)\nsave_video(video, \"video_Wan2.1-I2V-14B-480P.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_training/special/npu_training/Wan2.1-T2V-14B-NPU.sh",
    "content": "export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True\nexport CPU_AFFINITY_CONF=1\n\nmodelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"wanvideo/Wan2.1-T2V-14B/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.1-T2V-14B \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.1-T2V-14B/metadata.csv \\\n  --height 480 \\\n  --width 832 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"Wan-AI/Wan2.1-T2V-14B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-T2V-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-T2V-14B:Wan2.1_VAE.pth\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 2 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Wan2.1-T2V-14B_full\" \\\n  --trainable_models \"dit\" \\\n  --initialize_model_on_cpu\n"
  },
  {
    "path": "examples/wanvideo/model_training/special/npu_training/Wan2.2-T2V-A14B-NPU.sh",
    "content": "export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True\nexport CPU_AFFINITY_CONF=1\n\nmodelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"wanvideo/Wan2.2-T2V-A14B/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.2-T2V-A14B \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.2-T2V-A14B/metadata.csv \\\n  --height 480 \\\n  --width 832 \\\n  --num_frames 49 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"Wan-AI/Wan2.2-T2V-A14B:high_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-T2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-T2V-A14B:Wan2.1_VAE.pth\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 2 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Wan2.2-T2V-A14B_high_noise_full\" \\\n  --trainable_models \"dit\" \\\n  --max_timestep_boundary 0.417 \\\n  --min_timestep_boundary 0 \\\n  --initialize_model_on_cpu\n# boundary corresponds to timesteps [875, 1000]\n\naccelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.2-T2V-A14B \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.2-T2V-A14B/metadata.csv \\\n  --height 480 \\\n  --width 832 \\\n  --num_frames 49 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"Wan-AI/Wan2.2-T2V-A14B:low_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-T2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-T2V-A14B:Wan2.1_VAE.pth\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 2 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path 
\"./models/train/Wan2.2-T2V-A14B_low_noise_full\" \\\n  --trainable_models \"dit\" \\\n  --max_timestep_boundary 1 \\\n  --min_timestep_boundary 0.417 \\\n  --initialize_model_on_cpu\n# boundary corresponds to timesteps [0, 875)\n"
  },
  {
    "path": "examples/wanvideo/model_training/special/npu_training/Wan2.2-VACE-Fun-A14B-NPU.sh",
    "content": "export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True\nexport CPU_AFFINITY_CONF=1\n\nmodelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"wanvideo/Wan2.2-VACE-Fun-A14B/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.2-VACE-Fun-A14B \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.2-VACE-Fun-A14B/metadata.csv \\\n  --data_file_keys \"video,vace_video,vace_reference_image\" \\\n  --height 480 \\\n  --width 832 \\\n  --num_frames 17 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"PAI/Wan2.2-VACE-Fun-A14B:high_noise_model/diffusion_pytorch_model*.safetensors,PAI/Wan2.2-VACE-Fun-A14B:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.2-VACE-Fun-A14B:Wan2.1_VAE.pth\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 2 \\\n  --remove_prefix_in_ckpt \"pipe.vace.\" \\\n  --output_path \"./models/train/Wan2.2-VACE-Fun-A14B_high_noise_full\" \\\n  --trainable_models \"vace\" \\\n  --extra_inputs \"vace_video,vace_reference_image\" \\\n  --use_gradient_checkpointing_offload \\\n  --max_timestep_boundary 0.358 \\\n  --min_timestep_boundary 0 \\\n  --initialize_model_on_cpu\n# boundary corresponds to timesteps [900, 1000]\n\n\naccelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.2-VACE-Fun-A14B \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.2-VACE-Fun-A14B/metadata.csv \\\n  --data_file_keys \"video,vace_video,vace_reference_image\" \\\n  --height 480 \\\n  --width 832 \\\n  --num_frames 17 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths 
\"PAI/Wan2.2-VACE-Fun-A14B:low_noise_model/diffusion_pytorch_model*.safetensors,PAI/Wan2.2-VACE-Fun-A14B:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.2-VACE-Fun-A14B:Wan2.1_VAE.pth\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 2 \\\n  --remove_prefix_in_ckpt \"pipe.vace.\" \\\n  --output_path \"./models/train/Wan2.2-VACE-Fun-A14B_low_noise_full\" \\\n  --trainable_models \"vace\" \\\n  --extra_inputs \"vace_video,vace_reference_image\" \\\n  --use_gradient_checkpointing_offload \\\n  --max_timestep_boundary 1 \\\n  --min_timestep_boundary 0.358 \\\n  --initialize_model_on_cpu\n# boundary corresponds to timesteps [0, 900)\n"
  },
  {
    "path": "examples/wanvideo/model_training/special/split_training/Wan2.1-I2V-14B-480P.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"wanvideo/Wan2.1-I2V-14B-480P/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/wanvideo/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/wanvideo/Wan2.1-I2V-14B-480P \\\n  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/Wan2.1-I2V-14B-480P/metadata.csv \\\n  --height 480 \\\n  --width 832 \\\n  --dataset_repeat 1 \\\n  --model_id_with_origin_paths \"Wan-AI/Wan2.1-I2V-14B-480P:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-I2V-14B-480P:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-I2V-14B-480P:Wan2.1_VAE.pth,Wan-AI/Wan2.1-I2V-14B-480P:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Wan2.1-I2V-14B-480P_lora_split_cache\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"q,k,v,o,ffn.0,ffn.2\" \\\n  --lora_rank 32 \\\n  --extra_inputs \"input_image\" \\\n  --task \"sft:data_process\" \\\n  --offload_models \"Wan-AI/Wan2.1-I2V-14B-480P:diffusion_pytorch_model*.safetensors\"\n\naccelerate launch examples/wanvideo/model_training/train.py \\\n  --dataset_base_path \"./models/train/Wan2.1-I2V-14B-480P_lora_split_cache\" \\\n  --height 480 \\\n  --width 832 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"Wan-AI/Wan2.1-I2V-14B-480P:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-I2V-14B-480P:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-I2V-14B-480P:Wan2.1_VAE.pth,Wan-AI/Wan2.1-I2V-14B-480P:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Wan2.1-I2V-14B-480P_lora_split\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"q,k,v,o,ffn.0,ffn.2\" \\\n  --lora_rank 32 \\\n  --extra_inputs 
\"input_image\" \\\n  --task \"sft:train\" \\\n  --offload_models \"Wan-AI/Wan2.1-I2V-14B-480P:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-I2V-14B-480P:Wan2.1_VAE.pth,Wan-AI/Wan2.1-I2V-14B-480P:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\"\n"
  },
  {
    "path": "examples/wanvideo/model_training/special/split_training/validate.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-I2V-14B-480P\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-I2V-14B-480P\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-I2V-14B-480P\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-I2V-14B-480P\", origin_file_pattern=\"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\"),\n    ],\n)\npipe.load_lora(pipe.dit, \"models/train/Wan2.1-I2V-14B-480P_lora_split/epoch-4.safetensors\", alpha=1)\n\ninput_image = VideoData(\"data/example_video_dataset/video1.mp4\", height=480, width=832)[0]\n\nvideo = pipe(\n    prompt=\"from sunset to night, a small town, light, house, river\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    input_image=input_image,\n    seed=1, tiled=True\n)\nsave_video(video, \"video_Wan2.1-I2V-14B-480P.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_training/train.py",
    "content": "import torch, os, argparse, accelerate, warnings\nfrom diffsynth.core import UnifiedDataset\nfrom diffsynth.core.data.operators import LoadVideo, LoadAudio, ImageCropAndResize, ToAbsolutePath\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom diffsynth.diffusion import *\nos.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n\n\nclass WanTrainingModule(DiffusionTrainingModule):\n    def __init__(\n        self,\n        model_paths=None, model_id_with_origin_paths=None,\n        tokenizer_path=None, audio_processor_path=None,\n        trainable_models=None,\n        lora_base_model=None, lora_target_modules=\"\", lora_rank=32, lora_checkpoint=None,\n        preset_lora_path=None, preset_lora_model=None,\n        use_gradient_checkpointing=True,\n        use_gradient_checkpointing_offload=False,\n        extra_inputs=None,\n        fp8_models=None,\n        offload_models=None,\n        device=\"cpu\",\n        task=\"sft\",\n        max_timestep_boundary=1.0,\n        min_timestep_boundary=0.0,\n    ):\n        super().__init__()\n        # Warning\n        if not use_gradient_checkpointing:\n            warnings.warn(\"Gradient checkpointing is detected as disabled. 
To prevent out-of-memory errors, the training framework will forcibly enable gradient checkpointing.\")\n            use_gradient_checkpointing = True\n        \n        # Load models\n        model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device)\n        tokenizer_config = ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\") if tokenizer_path is None else ModelConfig(tokenizer_path)\n        audio_processor_config = self.parse_path_or_model_id(audio_processor_path)\n        self.pipe = WanVideoPipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config, audio_processor_config=audio_processor_config)\n        self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model)\n        \n        # Training mode\n        self.switch_pipe_to_training_mode(\n            self.pipe, trainable_models,\n            lora_base_model, lora_target_modules, lora_rank, lora_checkpoint,\n            preset_lora_path, preset_lora_model,\n            task=task,\n        )\n        \n        # Store other configs\n        self.use_gradient_checkpointing = use_gradient_checkpointing\n        self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload\n        self.extra_inputs = extra_inputs.split(\",\") if extra_inputs is not None else []\n        self.fp8_models = fp8_models\n        self.task = task\n        self.task_to_loss = {\n            \"sft:data_process\": lambda pipe, *args: args,\n            \"direct_distill:data_process\": lambda pipe, *args: args,\n            \"sft\": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),\n            \"sft:train\": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),\n            
\"direct_distill\": lambda pipe, inputs_shared, inputs_posi, inputs_nega: DirectDistillLoss(pipe, **inputs_shared, **inputs_posi),\n            \"direct_distill:train\": lambda pipe, inputs_shared, inputs_posi, inputs_nega: DirectDistillLoss(pipe, **inputs_shared, **inputs_posi),\n        }\n        self.max_timestep_boundary = max_timestep_boundary\n        self.min_timestep_boundary = min_timestep_boundary\n        \n    def parse_extra_inputs(self, data, extra_inputs, inputs_shared):\n        for extra_input in extra_inputs:\n            if extra_input == \"input_image\":\n                inputs_shared[\"input_image\"] = data[\"video\"][0]\n            elif extra_input == \"end_image\":\n                inputs_shared[\"end_image\"] = data[\"video\"][-1]\n            elif extra_input == \"reference_image\" or extra_input == \"vace_reference_image\":\n                inputs_shared[extra_input] = data[extra_input][0]\n            else:\n                inputs_shared[extra_input] = data[extra_input]\n        if inputs_shared.get(\"framewise_decoding\", False):\n            # WanToDance global model\n            inputs_shared[\"num_frames\"] = 4 * (len(data[\"video\"]) - 1) + 1\n        return inputs_shared\n    \n    def get_pipeline_inputs(self, data):\n        inputs_posi = {\"prompt\": data[\"prompt\"]}\n        inputs_nega = {}\n        inputs_shared = {\n            # Assume you are using this pipeline for inference,\n            # please fill in the input parameters.\n            \"input_video\": data[\"video\"],\n            \"height\": data[\"video\"][0].size[1],\n            \"width\": data[\"video\"][0].size[0],\n            \"num_frames\": len(data[\"video\"]),\n            # Please do not modify the following parameters\n            # unless you clearly know what this will cause.\n            \"cfg_scale\": 1,\n            \"tiled\": False,\n            \"rand_device\": self.pipe.device,\n            \"use_gradient_checkpointing\": 
self.use_gradient_checkpointing,\n            \"use_gradient_checkpointing_offload\": self.use_gradient_checkpointing_offload,\n            \"cfg_merge\": False,\n            \"vace_scale\": 1,\n            \"max_timestep_boundary\": self.max_timestep_boundary,\n            \"min_timestep_boundary\": self.min_timestep_boundary,\n        }\n        inputs_shared = self.parse_extra_inputs(data, self.extra_inputs, inputs_shared)\n        return inputs_shared, inputs_posi, inputs_nega\n    \n    def forward(self, data, inputs=None):\n        if inputs is None: inputs = self.get_pipeline_inputs(data)\n        inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype)\n        for unit in self.pipe.units:\n            inputs = self.pipe.unit_runner(unit, self.pipe, *inputs)\n        loss = self.task_to_loss[self.task](self.pipe, *inputs)\n        return loss\n\n\ndef wan_parser():\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser = add_general_config(parser)\n    parser = add_video_size_config(parser)\n    parser.add_argument(\"--tokenizer_path\", type=str, default=None, help=\"Path to tokenizer.\")\n    parser.add_argument(\"--audio_processor_path\", type=str, default=None, help=\"Path to the audio processor. 
If provided, the processor will be used for Wan2.2-S2V model.\")\n    parser.add_argument(\"--max_timestep_boundary\", type=float, default=1.0, help=\"Max timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).\")\n    parser.add_argument(\"--min_timestep_boundary\", type=float, default=0.0, help=\"Min timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).\")\n    parser.add_argument(\"--initialize_model_on_cpu\", default=False, action=\"store_true\", help=\"Whether to initialize models on CPU.\")\n    parser.add_argument(\"--framewise_decoding\", default=False, action=\"store_true\", help=\"Enable it if this model is a WanToDance global model.\")\n    return parser\n\n\nif __name__ == \"__main__\":\n    parser = wan_parser()\n    args = parser.parse_args()\n    accelerator = accelerate.Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        kwargs_handlers=[accelerate.DistributedDataParallelKwargs(find_unused_parameters=args.find_unused_parameters)],\n    )\n    dataset = UnifiedDataset(\n        base_path=args.dataset_base_path,\n        metadata_path=args.dataset_metadata_path,\n        repeat=args.dataset_repeat,\n        data_file_keys=args.data_file_keys.split(\",\"),\n        main_data_operator=UnifiedDataset.default_video_operator(\n            base_path=args.dataset_base_path,\n            max_pixels=args.max_pixels,\n            height=args.height,\n            width=args.width,\n            height_division_factor=16,\n            width_division_factor=16,\n            num_frames=args.num_frames,\n            time_division_factor=4 if not args.framewise_decoding else 1,\n            time_division_remainder=1 if not args.framewise_decoding else 0,\n        ),\n        special_operator_map={\n            \"animate_face_video\": ToAbsolutePath(args.dataset_base_path) >> LoadVideo(args.num_frames, 4, 1, frame_processor=ImageCropAndResize(512, 512, None, 16, 16)),\n            \"input_audio\": 
ToAbsolutePath(args.dataset_base_path) >> LoadAudio(sr=16000),\n            \"wantodance_music_path\": ToAbsolutePath(args.dataset_base_path),\n        }\n    )\n    model = WanTrainingModule(\n        model_paths=args.model_paths,\n        model_id_with_origin_paths=args.model_id_with_origin_paths,\n        tokenizer_path=args.tokenizer_path,\n        audio_processor_path=args.audio_processor_path,\n        trainable_models=args.trainable_models,\n        lora_base_model=args.lora_base_model,\n        lora_target_modules=args.lora_target_modules,\n        lora_rank=args.lora_rank,\n        lora_checkpoint=args.lora_checkpoint,\n        preset_lora_path=args.preset_lora_path,\n        preset_lora_model=args.preset_lora_model,\n        use_gradient_checkpointing=args.use_gradient_checkpointing,\n        use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,\n        extra_inputs=args.extra_inputs,\n        fp8_models=args.fp8_models,\n        offload_models=args.offload_models,\n        task=args.task,\n        device=\"cpu\" if args.initialize_model_on_cpu else accelerator.device,\n        max_timestep_boundary=args.max_timestep_boundary,\n        min_timestep_boundary=args.min_timestep_boundary,\n    )\n    model_logger = ModelLogger(\n        args.output_path,\n        remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,\n    )\n    launcher_map = {\n        \"sft:data_process\": launch_data_process_task,\n        \"direct_distill:data_process\": launch_data_process_task,\n        \"sft\": launch_training_task,\n        \"sft:train\": launch_training_task,\n        \"direct_distill\": launch_training_task,\n        \"direct_distill:train\": launch_training_task,\n    }\n    launcher_map[args.task](accelerator, dataset, model, model_logger, args=args)\n"
  },
  {
    "path": "examples/wanvideo/model_training/validate_full/LongCat-Video.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.core import load_state_dict\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"meituan-longcat/LongCat-Video\", origin_file_pattern=\"dit/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-14B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-14B\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n    ],\n)\nstate_dict = load_state_dict(\"models/train/LongCat-Video_full/epoch-1.safetensors\")\npipe.dit.load_state_dict(state_dict)\n\nvideo = pipe(\n    prompt=\"from sunset to night, a small town, light, house, river\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    seed=1, tiled=True\n)\nsave_video(video, \"video_LongCat-Video.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_training/validate_full/Video-As-Prompt-Wan2.1-14B.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.core import load_state_dict\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"ByteDance/Video-As-Prompt-Wan2.1-14B\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-I2V-14B-720P\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-I2V-14B-720P\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-I2V-14B-720P\", origin_file_pattern=\"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\"),\n    ],\n)\nstate_dict = load_state_dict(\"models/train/Video-As-Prompt-Wan2.1-14B_full/epoch-1.safetensors\")\npipe.vap.load_state_dict(state_dict)\n\nref_video_path = 'data/example_video_dataset/wanvap/vap_ref.mp4'\ntarget_image_path = 'data/example_video_dataset/wanvap/input_image.jpg'\n\nimage = Image.open(target_image_path).convert(\"RGB\")\nref_video = VideoData(ref_video_path, height=480, width=832)\nref_frames = [ref_video[i] for i in range(49)]\n\nvap_prompt = \"A man stands with his back to the camera on a dirt path overlooking sun-drenched, rolling green tea plantations. He wears a blue and green plaid shirt, dark pants, and white shoes. As he turns to face the camera and spreads his arms, a brief, magical burst of sparkling golden light particles envelops him. Through this shimmer, he seamlessly transforms into a Labubu toy character. His head morphs into the iconic large, furry-eared head of the toy, featuring a wide grin with pointed teeth and red cheek markings. The character retains the man's original plaid shirt and clothing, which now fit its stylized, cartoonish body. 
The camera remains static throughout the transformation, positioned low among the tea bushes, maintaining a consistent view of the subject and the expansive scenery.\"\nprompt = \"A young woman with curly hair, wearing a green hijab and a floral dress, plays a violin in front of a vintage green car on a tree-lined street. She executes a swift counter-clockwise turn to face the camera. During the turn, a brilliant shower of golden, sparkling particles erupts and momentarily obscures her figure. As the particles fade, she is revealed to have seamlessly transformed into a Labubu toy character. This new figure, now with the toy's signature large ears, big eyes, and toothy grin, maintains the original pose and continues playing the violin. The character's clothing—the green hijab, floral dress, and black overcoat—remains identical to the woman's. Throughout this transition, the camera stays static, and the street-side environment remains completely consistent.\"\nnegative_prompt = \"Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards\"\n\nvideo = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    input_image=image,\n    seed=42, tiled=True,\n    height=480, width=832,\n    num_frames=49,\n    vap_video=ref_frames,\n    vap_prompt=vap_prompt,\n    negative_vap_prompt=negative_prompt,\n)\nsave_video(video, \"video_Video-As-Prompt-Wan2.1-14B.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.core import load_state_dict\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n        ModelConfig(model_id=\"DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1\", origin_file_pattern=\"model.safetensors\"),\n    ],\n)\nstate_dict = load_state_dict(\"models/train/Wan2.1-1.3b-speedcontrol-v1_full/epoch-1.safetensors\")\npipe.motion_controller.load_state_dict(state_dict)\n\n# Text-to-video\nvideo = pipe(\n    prompt=\"from sunset to night, a small town, light, house, river\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    seed=1, tiled=True,\n    motion_bucket_id=50\n)\nsave_video(video, \"video_Wan2.1-1.3b-speedcontrol-v1.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_training/validate_full/Wan2.1-FLF2V-14B-720P.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.core import load_state_dict\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-FLF2V-14B-720P\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-FLF2V-14B-720P\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-FLF2V-14B-720P\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-FLF2V-14B-720P\", origin_file_pattern=\"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\"),\n    ],\n)\nstate_dict = load_state_dict(\"models/train/Wan2.1-FLF2V-14B-720P_full/epoch-1.safetensors\")\npipe.dit.load_state_dict(state_dict)\n\nvideo = VideoData(\"data/example_video_dataset/video1.mp4\", height=480, width=832)\n\n# First and last frame to video\nvideo = pipe(\n    prompt=\"from sunset to night, a small town, light, house, river\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    input_image=video[0],\n    end_image=video[80],\n    seed=0, tiled=True,\n    sigma_shift=16,\n)\nsave_video(video, \"video_Wan2.1-FLF2V-14B-720P.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-Control.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.core import load_state_dict\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-1.3B-Control\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-1.3B-Control\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-1.3B-Control\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-1.3B-Control\", origin_file_pattern=\"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\"),\n    ],\n)\nstate_dict = load_state_dict(\"models/train/Wan2.1-Fun-1.3B-Control_full/epoch-1.safetensors\")\npipe.dit.load_state_dict(state_dict)\n\nvideo = VideoData(\"data/example_video_dataset/video1_softedge.mp4\", height=480, width=832)\nvideo = [video[i] for i in range(81)]\n\n# Control video\nvideo = pipe(\n    prompt=\"from sunset to night, a small town, light, house, river\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    control_video=video,\n    seed=1, tiled=True\n)\nsave_video(video, \"video_Wan2.1-Fun-1.3B-Control.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-InP.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.core import load_state_dict\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-1.3B-InP\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-1.3B-InP\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-1.3B-InP\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-1.3B-InP\", origin_file_pattern=\"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\"),\n    ],\n)\nstate_dict = load_state_dict(\"models/train/Wan2.1-Fun-1.3B-InP_full/epoch-1.safetensors\")\npipe.dit.load_state_dict(state_dict)\n\nvideo = VideoData(\"data/example_video_dataset/video1.mp4\", height=480, width=832)\n\n# First and last frame to video\nvideo = pipe(\n    prompt=\"from sunset to night, a small town, light, house, river\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    input_image=video[0], end_image=video[80],\n    seed=0, tiled=True\n)\nsave_video(video, \"video_Wan2.1-Fun-1.3B-InP.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-Control.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.core import load_state_dict\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-14B-Control\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-14B-Control\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-14B-Control\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-14B-Control\", origin_file_pattern=\"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\"),\n    ],\n)\nstate_dict = load_state_dict(\"models/train/Wan2.1-Fun-14B-Control_full/epoch-1.safetensors\")\npipe.dit.load_state_dict(state_dict)\n\nvideo = VideoData(\"data/example_video_dataset/video1_softedge.mp4\", height=480, width=832)\nvideo = [video[i] for i in range(81)]\n\n# Control video\nvideo = pipe(\n    prompt=\"from sunset to night, a small town, light, house, river\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    control_video=video,\n    seed=1, tiled=True\n)\nsave_video(video, \"video_Wan2.1-Fun-14B-Control.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-InP.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.core import load_state_dict\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-14B-InP\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-14B-InP\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-14B-InP\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-14B-InP\", origin_file_pattern=\"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\"),\n    ],\n)\nstate_dict = load_state_dict(\"models/train/Wan2.1-Fun-14B-InP_full/epoch-1.safetensors\")\npipe.dit.load_state_dict(state_dict)\n\nvideo = VideoData(\"data/example_video_dataset/video1.mp4\", height=480, width=832)\n\n# First and last frame to video\nvideo = pipe(\n    prompt=\"from sunset to night, a small town, light, house, river\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    input_image=video[0], end_image=video[80],\n    seed=0, tiled=True\n)\nsave_video(video, \"video_Wan2.1-Fun-14B-InP.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.core import load_state_dict\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera\", origin_file_pattern=\"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\"),\n    ],\n)\nstate_dict = load_state_dict(\"models/train/Wan2.1-Fun-V1.1-1.3B-Control-Camera_full/epoch-1.safetensors\")\npipe.dit.load_state_dict(state_dict)\n\nvideo = VideoData(\"data/example_video_dataset/video1.mp4\", height=480, width=832)\n\n# Image to video with camera control\nvideo = pipe(\n    prompt=\"from sunset to night, a small town, light, house, river\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    input_image=video[0],\n    camera_control_direction=\"Left\", camera_control_speed=0.0,\n    seed=0, tiled=True\n)\nsave_video(video, \"video_Wan2.1-Fun-V1.1-1.3B-Control-Camera.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.core import load_state_dict\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-1.3B-Control\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-1.3B-Control\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-1.3B-Control\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-1.3B-Control\", origin_file_pattern=\"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\"),\n    ],\n)\nstate_dict = load_state_dict(\"models/train/Wan2.1-Fun-V1.1-1.3B-Control_full/epoch-1.safetensors\")\npipe.dit.load_state_dict(state_dict)\n\nvideo = VideoData(\"data/example_video_dataset/video1_softedge.mp4\", height=480, width=832)\nvideo = [video[i] for i in range(81)]\nreference_image = VideoData(\"data/example_video_dataset/video1.mp4\", height=480, width=832)[0]\n\n# Control video\nvideo = pipe(\n    prompt=\"from sunset to night, a small town, light, house, river\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    control_video=video, reference_image=reference_image,\n    seed=1, tiled=True\n)\nsave_video(video, \"video_Wan2.1-Fun-V1.1-1.3B-Control.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.core import load_state_dict\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-1.3B-InP\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-1.3B-InP\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-1.3B-InP\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-1.3B-InP\", origin_file_pattern=\"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\"),\n    ],\n)\nstate_dict = load_state_dict(\"models/train/Wan2.1-Fun-V1.1-1.3B-InP_full/epoch-1.safetensors\")\npipe.dit.load_state_dict(state_dict)\n\nvideo = VideoData(\"data/example_video_dataset/video1.mp4\", height=480, width=832)\n\n# First and last frame to video\nvideo = pipe(\n    prompt=\"from sunset to night, a small town, light, house, river\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    input_image=video[0], end_image=video[80],\n    seed=0, tiled=True\n)\nsave_video(video, \"video_Wan2.1-Fun-V1.1-1.3B-InP.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control-Camera.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.core import load_state_dict\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-14B-Control-Camera\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-14B-Control-Camera\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-14B-Control-Camera\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-14B-Control-Camera\", origin_file_pattern=\"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\"),\n    ],\n)\nstate_dict = load_state_dict(\"models/train/Wan2.1-Fun-V1.1-14B-Control-Camera_full/epoch-1.safetensors\")\npipe.dit.load_state_dict(state_dict)\n\nvideo = VideoData(\"data/example_video_dataset/video1.mp4\", height=480, width=832)\n\n# Image to video with camera control\nvideo = pipe(\n    prompt=\"from sunset to night, a small town, light, house, river\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    input_image=video[0],\n    camera_control_direction=\"Left\", camera_control_speed=0.0,\n    seed=0, tiled=True\n)\nsave_video(video, \"video_Wan2.1-Fun-V1.1-14B-Control-Camera.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.core import load_state_dict\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-14B-Control\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-14B-Control\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-14B-Control\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-14B-Control\", origin_file_pattern=\"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\"),\n    ],\n)\nstate_dict = load_state_dict(\"models/train/Wan2.1-Fun-V1.1-14B-Control_full/epoch-1.safetensors\")\npipe.dit.load_state_dict(state_dict)\n\nvideo = VideoData(\"data/example_video_dataset/video1_softedge.mp4\", height=480, width=832)\nvideo = [video[i] for i in range(81)]\nreference_image = VideoData(\"data/example_video_dataset/video1.mp4\", height=480, width=832)[0]\n\n# Control video\nvideo = pipe(\n    prompt=\"from sunset to night, a small town, light, house, river\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    control_video=video, reference_image=reference_image,\n    seed=1, tiled=True\n)\nsave_video(video, \"video_Wan2.1-Fun-V1.1-14B-Control.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.core import load_state_dict\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-14B-InP\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-14B-InP\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-14B-InP\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-14B-InP\", origin_file_pattern=\"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\"),\n    ],\n)\nstate_dict = load_state_dict(\"models/train/Wan2.1-Fun-V1.1-14B-InP_full/epoch-1.safetensors\")\npipe.dit.load_state_dict(state_dict)\n\nvideo = VideoData(\"data/example_video_dataset/video1.mp4\", height=480, width=832)\n\n# First and last frame to video\nvideo = pipe(\n    prompt=\"from sunset to night, a small town, light, house, river\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    input_image=video[0], end_image=video[80],\n    seed=0, tiled=True\n)\nsave_video(video, \"video_Wan2.1-Fun-V1.1-14B-InP.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.core import load_state_dict\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-I2V-14B-480P\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-I2V-14B-480P\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-I2V-14B-480P\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-I2V-14B-480P\", origin_file_pattern=\"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\"),\n    ],\n)\nstate_dict = load_state_dict(\"models/train/Wan2.1-I2V-14B-480P_full/epoch-1.safetensors\")\npipe.dit.load_state_dict(state_dict)\n\ninput_image = VideoData(\"data/example_video_dataset/video1.mp4\", height=480, width=832)[0]\n\nvideo = pipe(\n    prompt=\"from sunset to night, a small town, light, house, river\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    input_image=input_image,\n    seed=1, tiled=True\n)\nsave_video(video, \"video_Wan2.1-I2V-14B-480P.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-720P.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.core import load_state_dict\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-I2V-14B-720P\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-I2V-14B-720P\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-I2V-14B-720P\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-I2V-14B-720P\", origin_file_pattern=\"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\"),\n    ],\n)\nstate_dict = load_state_dict(\"models/train/Wan2.1-I2V-14B-720P_full/epoch-1.safetensors\")\npipe.dit.load_state_dict(state_dict)\n\ninput_image = VideoData(\"data/example_video_dataset/video1.mp4\", height=720, width=1280)[0]\n\nvideo = pipe(\n    prompt=\"from sunset to night, a small town, light, house, river\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    input_image=input_image,\n    height=720, width=1280, num_frames=49,\n    seed=1, tiled=True\n)\nsave_video(video, \"video_Wan2.1-I2V-14B-720P.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_training/validate_full/Wan2.1-T2V-1.3B.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.core import load_state_dict\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n    ],\n)\nstate_dict = load_state_dict(\"models/train/Wan2.1-T2V-1.3B_full/epoch-1.safetensors\")\npipe.dit.load_state_dict(state_dict)\n\nvideo = pipe(\n    prompt=\"from sunset to night, a small town, light, house, river\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    seed=1, tiled=True\n)\nsave_video(video, \"video_Wan2.1-T2V-1.3B.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.core import load_state_dict\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-14B\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-14B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-14B\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n    ],\n)\nstate_dict = load_state_dict(\"models/train/Wan2.1-T2V-14B_full/epoch-1.safetensors\")\npipe.dit.load_state_dict(state_dict)\n\nvideo = pipe(\n    prompt=\"from sunset to night, a small town, light, house, river\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    seed=1, tiled=True\n)\nsave_video(video, \"video_Wan2.1-T2V-14B.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B-Preview.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.core import load_state_dict\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"iic/VACE-Wan2.1-1.3B-Preview\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"iic/VACE-Wan2.1-1.3B-Preview\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"iic/VACE-Wan2.1-1.3B-Preview\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n    ],\n)\nstate_dict = load_state_dict(\"models/train/Wan2.1-VACE-1.3B-Preview_full/epoch-1.safetensors\")\npipe.vace.load_state_dict(state_dict)\n\nvideo = VideoData(\"data/example_video_dataset/video1_softedge.mp4\", height=480, width=832)\nvideo = [video[i] for i in range(49)]\nreference_image = VideoData(\"data/example_video_dataset/video1.mp4\", height=480, width=832)[0]\n\nvideo = pipe(\n    prompt=\"from sunset to night, a small town, light, house, river\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    vace_video=video, vace_reference_image=reference_image, num_frames=49,\n    seed=1, tiled=True\n)\nsave_video(video, \"video_Wan2.1-VACE-1.3B-Preview.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.core import load_state_dict\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-VACE-1.3B\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-VACE-1.3B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-VACE-1.3B\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n    ],\n)\nstate_dict = load_state_dict(\"models/train/Wan2.1-VACE-1.3B_full/epoch-1.safetensors\")\npipe.vace.load_state_dict(state_dict)\n\nvideo = VideoData(\"data/example_video_dataset/video1_softedge.mp4\", height=480, width=832)\nvideo = [video[i] for i in range(49)]\nreference_image = VideoData(\"data/example_video_dataset/video1.mp4\", height=480, width=832)[0]\n\nvideo = pipe(\n    prompt=\"from sunset to night, a small town, light, house, river\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    vace_video=video, vace_reference_image=reference_image, num_frames=49,\n    seed=1, tiled=True\n)\nsave_video(video, \"video_Wan2.1-VACE-1.3B.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.core import load_state_dict\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-VACE-14B\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-VACE-14B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-VACE-14B\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n    ],\n)\nstate_dict = load_state_dict(\"models/train/Wan2.1-VACE-14B_full/epoch-1.safetensors\")\npipe.vace.load_state_dict(state_dict)\n\nvideo = VideoData(\"data/example_video_dataset/video1_softedge.mp4\", height=480, width=832)\nvideo = [video[i] for i in range(17)]\nreference_image = VideoData(\"data/example_video_dataset/video1.mp4\", height=480, width=832)[0]\n\nvideo = pipe(\n    prompt=\"from sunset to night, a small town, light, house, river\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    vace_video=video, vace_reference_image=reference_image, num_frames=17,\n    seed=1, tiled=True\n)\nsave_video(video, \"video_Wan2.1-VACE-14B.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.core import load_state_dict\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-Animate-14B\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-Animate-14B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-Animate-14B\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-Animate-14B\", origin_file_pattern=\"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\"),\n    ],\n)\nstate_dict = load_state_dict(\"models/train/Wan2.2-Animate-14B_full/epoch-1.safetensors\")\npipe.animate_adapter.load_state_dict(state_dict, strict=False)\n\ninput_image = VideoData(\"data/example_video_dataset/animate/animate_output.mp4\", height=480, width=832)[0]\nanimate_pose_video = VideoData(\"data/examples/wan/animate/animate_pose_video.mp4\", height=480, width=832).raw_data()[:81-4]\nanimate_face_video = VideoData(\"data/examples/wan/animate/animate_face_video.mp4\", height=512, width=512).raw_data()[:81-4]\nvideo = pipe(\n    prompt=\"视频中的人在做动作\",\n    seed=0, tiled=True,\n    input_image=input_image,\n    animate_pose_video=animate_pose_video,\n    animate_face_video=animate_face_video,\n    num_frames=81, height=480, width=832,\n    num_inference_steps=20, cfg_scale=1,\n)\nsave_video(video, \"video_Wan2.2-Animate-14B.mp4\", fps=15, quality=5)"
  },
  {
    "path": "examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.core import load_state_dict\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"PAI/Wan2.2-Fun-A14B-Control-Camera\", origin_file_pattern=\"high_noise_model/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"PAI/Wan2.2-Fun-A14B-Control-Camera\", origin_file_pattern=\"low_noise_model/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"PAI/Wan2.2-Fun-A14B-Control-Camera\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"PAI/Wan2.2-Fun-A14B-Control-Camera\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n    ],\n)\nstate_dict = load_state_dict(\"models/train/Wan2.2-Fun-A14B-Control-Camera_high_noise_full/epoch-1.safetensors\")\npipe.dit.load_state_dict(state_dict)\nstate_dict = load_state_dict(\"models/train/Wan2.2-Fun-A14B-Control-Camera_low_noise_full/epoch-1.safetensors\")\npipe.dit2.load_state_dict(state_dict)\n\nvideo = VideoData(\"data/example_video_dataset/video1.mp4\", height=480, width=832)\n\n# Image to video with camera control\nvideo = pipe(\n    prompt=\"from sunset to night, a small town, light, house, river\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    input_image=video[0],\n    camera_control_direction=\"Left\", camera_control_speed=0.0,\n    seed=0, tiled=True\n)\nsave_video(video, \"video_Wan2.2-Fun-A14B-Control-Camera.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.core import load_state_dict\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"PAI/Wan2.2-Fun-A14B-Control\", origin_file_pattern=\"high_noise_model/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"PAI/Wan2.2-Fun-A14B-Control\", origin_file_pattern=\"low_noise_model/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"PAI/Wan2.2-Fun-A14B-Control\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"PAI/Wan2.2-Fun-A14B-Control\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n    ],\n)\nstate_dict = load_state_dict(\"models/train/Wan2.2-Fun-A14B-Control_high_noise_full/epoch-1.safetensors\")\npipe.dit.load_state_dict(state_dict)\nstate_dict = load_state_dict(\"models/train/Wan2.2-Fun-A14B-Control_low_noise_full/epoch-1.safetensors\")\npipe.dit2.load_state_dict(state_dict)\n\nvideo = VideoData(\"data/example_video_dataset/video1_softedge.mp4\", height=480, width=832)\nvideo = [video[i] for i in range(81)]\nreference_image = VideoData(\"data/example_video_dataset/video1.mp4\", height=480, width=832)[0]\n\n# Control video\nvideo = pipe(\n    prompt=\"from sunset to night, a small town, light, house, river\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    control_video=video, reference_image=reference_image,\n    seed=1, tiled=True\n)\nsave_video(video, \"video_Wan2.2-Fun-A14B-Control.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.core import load_state_dict\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"PAI/Wan2.2-Fun-A14B-InP\", origin_file_pattern=\"high_noise_model/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"PAI/Wan2.2-Fun-A14B-InP\", origin_file_pattern=\"low_noise_model/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"PAI/Wan2.2-Fun-A14B-InP\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"PAI/Wan2.2-Fun-A14B-InP\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n    ],\n)\nstate_dict = load_state_dict(\"models/train/Wan2.2-Fun-A14B-InP_high_noise_full/epoch-1.safetensors\")\npipe.dit.load_state_dict(state_dict)\nstate_dict = load_state_dict(\"models/train/Wan2.2-Fun-A14B-InP_low_noise_full/epoch-1.safetensors\")\npipe.dit2.load_state_dict(state_dict)\nvideo = VideoData(\"data/example_video_dataset/video1.mp4\", height=480, width=832)\n\n# First and last frame to video\nvideo = pipe(\n    prompt=\"from sunset to night, a small town, light, house, river\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    input_image=video[0], end_image=video[80],\n    seed=0, tiled=True\n)\nsave_video(video, \"video_Wan2.2-Fun-A14B-InP.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.core import load_state_dict\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-I2V-A14B\", origin_file_pattern=\"high_noise_model/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-I2V-A14B\", origin_file_pattern=\"low_noise_model/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-I2V-A14B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-I2V-A14B\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n    ],\n)\nstate_dict = load_state_dict(\"models/train/Wan2.2-I2V-A14B_high_noise_full/epoch-1.safetensors\")\npipe.dit.load_state_dict(state_dict)\nstate_dict = load_state_dict(\"models/train/Wan2.2-I2V-A14B_low_noise_full/epoch-1.safetensors\")\npipe.dit2.load_state_dict(state_dict)\n\ninput_image = VideoData(\"data/example_video_dataset/video1.mp4\", height=480, width=832)[0]\n\nvideo = pipe(\n    prompt=\"from sunset to night, a small town, light, house, river\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    input_image=input_image,\n    num_frames=49,\n    seed=1, tiled=True,\n)\nsave_video(video, \"video_Wan2.2-I2V-A14B.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_training/validate_full/Wan2.2-S2V-14B.py",
    "content": "import torch\nfrom PIL import Image\nimport librosa\nfrom diffsynth.utils.data import VideoData, save_video_with_audio\nfrom diffsynth.core import load_state_dict\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-S2V-14B\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-S2V-14B\", origin_file_pattern=\"wav2vec2-large-xlsr-53-english/model.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-S2V-14B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-S2V-14B\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n    ],\n    audio_processor_config=ModelConfig(model_id=\"Wan-AI/Wan2.2-S2V-14B\", origin_file_pattern=\"wav2vec2-large-xlsr-53-english/\"),\n)\n\nstate_dict = load_state_dict(\"models/train/Wan2.2-S2V-14B_full/epoch-0.safetensors\")\npipe.dit.load_state_dict(state_dict, strict=False)\n\n\nnum_frames = 81 # 4n+1\nheight = 448\nwidth = 832\n\nprompt = \"a person is singing\"\nnegative_prompt = \"画面模糊，最差质量，画面模糊，细节模糊不清，情绪激动剧烈，手快速抖动，字幕，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\"\ninput_image = Image.open(\"data/example_video_dataset/wans2v/pose.png\").convert(\"RGB\").resize((width, height))\n# s2v audio input, recommend 16kHz sampling rate\naudio_path = 'data/example_video_dataset/wans2v/sing.MP3'\ninput_audio, sample_rate = librosa.load(audio_path, sr=16000)\n# S2V pose video input \npose_video_path = 'data/example_video_dataset/wans2v/pose.mp4'\npose_video = VideoData(pose_video_path, height=height, width=width)\n\n# Speech-to-video with pose\nvideo = pipe(\n    prompt=prompt,\n    input_image=input_image,\n    negative_prompt=negative_prompt,\n    seed=0,\n    num_frames=num_frames,\n    height=height,\n  
  width=width,\n    audio_sample_rate=sample_rate,\n    input_audio=input_audio,\n    s2v_pose_video=pose_video,\n    num_inference_steps=40,\n)\nsave_video_with_audio(video[1:], \"video_Wan2.2-S2V-14B.mp4\", audio_path, fps=16, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.core import load_state_dict\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-T2V-A14B\", origin_file_pattern=\"high_noise_model/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-T2V-A14B\", origin_file_pattern=\"low_noise_model/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-T2V-A14B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-T2V-A14B\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n    ],\n)\nstate_dict = load_state_dict(\"models/train/Wan2.2-T2V-A14B_high_noise_full/epoch-1.safetensors\")\npipe.dit.load_state_dict(state_dict)\nstate_dict = load_state_dict(\"models/train/Wan2.2-T2V-A14B_low_noise_full/epoch-1.safetensors\")\npipe.dit2.load_state_dict(state_dict)\n\nvideo = pipe(\n    prompt=\"from sunset to night, a small town, light, house, river\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    seed=1, tiled=True\n)\nsave_video(video, \"video_Wan2.2-T2V-A14B.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.core import load_state_dict\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-TI2V-5B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-TI2V-5B\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-TI2V-5B\", origin_file_pattern=\"Wan2.2_VAE.pth\"),\n    ],\n)\nstate_dict = load_state_dict(\"models/train/Wan2.2-TI2V-5B_full/epoch-1.safetensors\")\npipe.dit.load_state_dict(state_dict)\n\ninput_image = VideoData(\"data/example_video_dataset/video1.mp4\", height=480, width=832)[0]\n\nvideo = pipe(\n    prompt=\"from sunset to night, a small town, light, house, river\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    input_image=input_image,\n    num_frames=49,\n    seed=1, tiled=True,\n)\nsave_video(video, \"video_Wan2.2-TI2V-5B.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_training/validate_full/Wan2.2-VACE-Fun-A14B.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.core import load_state_dict\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\n\n\nvram_config = {\n    \"offload_dtype\": torch.bfloat16,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"PAI/Wan2.2-VACE-Fun-A14B\", origin_file_pattern=\"high_noise_model/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"PAI/Wan2.2-VACE-Fun-A14B\", origin_file_pattern=\"low_noise_model/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"PAI/Wan2.2-VACE-Fun-A14B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\", **vram_config),\n        ModelConfig(model_id=\"PAI/Wan2.2-VACE-Fun-A14B\", origin_file_pattern=\"Wan2.1_VAE.pth\", **vram_config),\n    ],\n)\nstate_dict = load_state_dict(\"models/train/Wan2.2-VACE-Fun-A14B_high_noise_full/epoch-1.safetensors\", torch_dtype=torch.bfloat16, device=\"cpu\")\npipe.vace.load_state_dict(state_dict)\nstate_dict = load_state_dict(\"models/train/Wan2.2-VACE-Fun-A14B_low_noise_full/epoch-1.safetensors\", torch_dtype=torch.bfloat16, device=\"cpu\")\npipe.vace2.load_state_dict(state_dict)\n\nvideo = VideoData(\"data/example_video_dataset/video1_softedge.mp4\", height=480, width=832)\nvideo = [video[i] for i in range(17)]\nreference_image = VideoData(\"data/example_video_dataset/video1.mp4\", height=480, width=832)[0]\n\nvideo = pipe(\n    prompt=\"from sunset to night, a small town, light, house, river\",\n    
negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    vace_video=video, vace_reference_image=reference_image, num_frames=17,\n    seed=1, tiled=True\n)\nsave_video(video, \"video_Wan2.2-VACE-A14B.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_training/validate_full/WanToDance-14B-global.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\nfrom diffsynth.core import load_state_dict\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/WanToDance-14B\", origin_file_pattern=\"global_model.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/WanToDance-14B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"Wan-AI/WanToDance-14B\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n        ModelConfig(model_id=\"Wan-AI/WanToDance-14B\", origin_file_pattern=\"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n)\nstate_dict = load_state_dict(\"models/train/WanToDance-14B-global_full/epoch-1.safetensors\")\npipe.dit.load_state_dict(state_dict)\ndataset_snapshot_download(\n    \"DiffSynth-Studio/diffsynth_example_dataset\",\n    local_dir=\"data/diffsynth_example_dataset\",\n    allow_file_pattern=\"wanvideo/WanToDance-14B-global/*\"\n)\n# This is a specialized model with the following constraints on its input parameters:\n# *   The model outputs a sequence of keyframes rather than a video; therefore, `framewise_decoding=True` must be set.\n# *   When the number of keyframes is $n$, `num_frames` = 4 * (n - 1) + 1.\n# *   Reducing `height`, `width`, `num_frames`, or `num_inference_steps` may lead to severe artifacts or generation failure.\n# *   The audio file specified by `wantodance_music_path` must match the video duration, calculated as (`num_frames` / 7.5) seconds.\n# *   The width and height of `wantodance_reference_image` must be multiples of 16.\n# *   `wantodance_fps` is configurable, but since the model 
appears to have been trained exclusively at 7.5 FPS, setting it to other values is not recommended.\n# *   The first frame of `wantodance_keyframes` is the `wantodance_reference_image`, while all subsequent frames are solid black.\n# *   `wantodance_keyframes_mask` indicates the positions of valid frames within `wantodance_keyframes`.\nwantodance_keyframes = VideoData(\"data/diffsynth_example_dataset/wanvideo/WanToDance-14B-global/keyframes.mp4\")\nwantodance_keyframes = [wantodance_keyframes[i] for i in range(149)]\nvideo = pipe(\n    prompt=\"一个人正在跳舞，舞蹈种类是韩舞。帧率是7.5000\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    seed=0, tiled=False,\n    height=1280, width=720, num_frames=149,\n    num_inference_steps=48,\n    wantodance_music_path=\"data/diffsynth_example_dataset/wanvideo/WanToDance-14B-global/music.WAV\",\n    wantodance_reference_image=Image.open(\"data/diffsynth_example_dataset/wanvideo/WanToDance-14B-global/refimage.jpg\"),\n    wantodance_fps=7.5,\n    wantodance_keyframes=wantodance_keyframes,\n    wantodance_keyframes_mask=[1] + [0] * 148,\n    framewise_decoding=True,\n)\nsave_video(video, \"video_WanToDance-14B-global.mp4\", fps=7.5, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_training/validate_full/WanToDance-14B-local.py",
    "content": "import torch, os\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\nfrom diffsynth.core import load_state_dict\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/WanToDance-14B\", origin_file_pattern=\"local_model.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/WanToDance-14B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"Wan-AI/WanToDance-14B\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n        ModelConfig(model_id=\"Wan-AI/WanToDance-14B\", origin_file_pattern=\"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n)\nstate_dict = load_state_dict(\"models/train/WanToDance-14B-local_full/epoch-1.safetensors\")\npipe.dit.load_state_dict(state_dict)\ndataset_snapshot_download(\n    \"DiffSynth-Studio/diffsynth_example_dataset\",\n    local_dir=\"data/diffsynth_example_dataset\",\n    allow_file_pattern=\"wanvideo/WanToDance-14B-local/*\"\n)\n# This is a specialized model with the following constraints on its input parameters:\n# *   The model renders and outputs video based on a sequence of keyframes; therefore, `wantodance_keyframes` must be provided correctly.\n# *   If you need to generate a long video, please generate it in segments, and ensure that `wantodance_music_path`, `wantodance_keyframes`, and `wantodance_keyframes_mask` are properly split accordingly.\n# *   The audio file specified by `wantodance_music_path` must match the video duration, calculated as (`num_frames` / 30) seconds.\n# *   The width and height of `wantodance_reference_image` must be multiples of 16.\n# *   `wantodance_fps` is 
configurable, but since the model appears to have been trained exclusively at 30 FPS, setting it to other values is not recommended.\n# *   In `wantodance_keyframes`, frames that are not keyframes should be solid black.\n# *   `wantodance_keyframes_mask` indicates the positions of valid frames within `wantodance_keyframes`.\nwantodance_keyframes = VideoData(\"data/diffsynth_example_dataset/wanvideo/WanToDance-14B-local/keyframes.mp4\")\nwantodance_keyframes = [wantodance_keyframes[i] for i in range(149)]\nvideo = pipe(\n    prompt=\"一个人正在跳舞，舞蹈种类是古典舞,图像清晰程度高,人物动作平均幅度中等,人物动作最大幅度中等。, 帧率是30fps。\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    seed=0, tiled=True,\n    height=1280, width=720, num_frames=149,\n    num_inference_steps=24,\n    wantodance_music_path=\"data/diffsynth_example_dataset/wanvideo/WanToDance-14B-local/music.wav\",\n    wantodance_reference_image=Image.open(\"data/diffsynth_example_dataset/wanvideo/WanToDance-14B-local/refimage.jpg\"),\n    wantodance_fps=30,\n    wantodance_keyframes=wantodance_keyframes,\n    wantodance_keyframes_mask=[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n                               1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n                               1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n                               1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n                               1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n                               1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n                               1],\n)\nsave_video(video, \"video_WanToDance-14B-local.mp4\", fps=30, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_training/validate_full/krea-realtime-video.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.core import load_state_dict\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"krea/krea-realtime-video\", origin_file_pattern=\"krea-realtime-video-14b.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-14B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-14B\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n    ],\n)\nstate_dict = load_state_dict(\"models/train/krea-realtime-video_full/epoch-1.safetensors\")\npipe.dit.load_state_dict(state_dict)\n\n# Text-to-video\nvideo = pipe(\n    prompt=\"a cat sitting on a boat\",\n    num_inference_steps=6, num_frames=81,\n    seed=0, tiled=True,\n    cfg_scale=1,\n    sigma_shift=20,\n)\nsave_video(video, \"video_krea-realtime-video.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_training/validate_lora/LongCat-Video.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"meituan-longcat/LongCat-Video\", origin_file_pattern=\"dit/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-14B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-14B\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n    ],\n)\npipe.load_lora(pipe.dit, \"models/train/LongCat-Video_lora/epoch-4.safetensors\", alpha=1)\n\nvideo = pipe(\n    prompt=\"from sunset to night, a small town, light, house, river\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    seed=1, tiled=True\n)\nsave_video(video, \"video_LongCat-Video.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_training/validate_lora/Video-As-Prompt-Wan2.1-14B.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.core import load_state_dict\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"ByteDance/Video-As-Prompt-Wan2.1-14B\", origin_file_pattern=\"transformer/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-I2V-14B-720P\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-I2V-14B-720P\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-I2V-14B-720P\", origin_file_pattern=\"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\"),\n    ],\n)\npipe.load_lora(pipe.dit, \"models/train/Video-As-Prompt-Wan2.1-14B_lora/epoch-4.safetensors\", alpha=1)\n\nref_video_path = 'data/example_video_dataset/wanvap/vap_ref.mp4'\ntarget_image_path = 'data/example_video_dataset/wanvap/input_image.jpg'\n\nimage = Image.open(target_image_path).convert(\"RGB\")\nref_video = VideoData(ref_video_path, height=480, width=832)\nref_frames = [ref_video[i] for i in range(49)]\n\nvap_prompt = \"A man stands with his back to the camera on a dirt path overlooking sun-drenched, rolling green tea plantations. He wears a blue and green plaid shirt, dark pants, and white shoes. As he turns to face the camera and spreads his arms, a brief, magical burst of sparkling golden light particles envelops him. Through this shimmer, he seamlessly transforms into a Labubu toy character. His head morphs into the iconic large, furry-eared head of the toy, featuring a wide grin with pointed teeth and red cheek markings. The character retains the man's original plaid shirt and clothing, which now fit its stylized, cartoonish body. 
The camera remains static throughout the transformation, positioned low among the tea bushes, maintaining a consistent view of the subject and the expansive scenery.\"\nprompt = \"A young woman with curly hair, wearing a green hijab and a floral dress, plays a violin in front of a vintage green car on a tree-lined street. She executes a swift counter-clockwise turn to face the camera. During the turn, a brilliant shower of golden, sparkling particles erupts and momentarily obscures her figure. As the particles fade, she is revealed to have seamlessly transformed into a Labubu toy character. This new figure, now with the toy's signature large ears, big eyes, and toothy grin, maintains the original pose and continues playing the violin. The character's clothing—the green hijab, floral dress, and black overcoat—remains identical to the woman's. Throughout this transition, the camera stays static, and the street-side environment remains completely consistent.\"\nnegative_prompt = \"Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards\"\n\nvideo = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    input_image=image,\n    seed=42, tiled=True,\n    height=480, width=832,\n    num_frames=49,\n    vap_video=ref_frames,\n    vap_prompt=vap_prompt,\n    negative_vap_prompt=negative_prompt,\n)\nsave_video(video, \"video_Video-As-Prompt-Wan2.1-14B.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.core import load_state_dict\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n        ModelConfig(model_id=\"DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1\", origin_file_pattern=\"model.safetensors\"),\n    ],\n)\npipe.load_lora(pipe.dit, \"models/train/Wan2.1-1.3b-speedcontrol-v1_lora/epoch-4.safetensors\", alpha=1)\n\n# Text-to-video\nvideo = pipe(\n    prompt=\"from sunset to night, a small town, light, house, river\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    seed=1, tiled=True,\n    motion_bucket_id=50\n)\nsave_video(video, \"video_Wan2.1-1.3b-speedcontrol-v1.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-FLF2V-14B-720P\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-FLF2V-14B-720P\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-FLF2V-14B-720P\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-FLF2V-14B-720P\", origin_file_pattern=\"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\"),\n    ],\n)\npipe.load_lora(pipe.dit, \"models/train/Wan2.1-FLF2V-14B-720P_lora/epoch-4.safetensors\", alpha=1)\n\nvideo = VideoData(\"data/example_video_dataset/video1.mp4\", height=480, width=832)\n\n# First and last frame to video\nvideo = pipe(\n    prompt=\"from sunset to night, a small town, light, house, river\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    input_image=video[0],\n    end_image=video[80],\n    seed=0, tiled=True,\n    sigma_shift=16,\n)\nsave_video(video, \"video_Wan2.1-FLF2V-14B-720P.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-1.3B-Control\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-1.3B-Control\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-1.3B-Control\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-1.3B-Control\", origin_file_pattern=\"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\"),\n    ],\n)\npipe.load_lora(pipe.dit, \"models/train/Wan2.1-Fun-1.3B-Control_lora/epoch-4.safetensors\", alpha=1)\n\nvideo = VideoData(\"data/example_video_dataset/video1_softedge.mp4\", height=480, width=832)\nvideo = [video[i] for i in range(81)]\n\n# Control video\nvideo = pipe(\n    prompt=\"from sunset to night, a small town, light, house, river\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    control_video=video,\n    seed=1, tiled=True\n)\nsave_video(video, \"video_Wan2.1-Fun-1.3B-Control.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-1.3B-InP\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-1.3B-InP\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-1.3B-InP\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-1.3B-InP\", origin_file_pattern=\"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\"),\n    ],\n)\npipe.load_lora(pipe.dit, \"models/train/Wan2.1-Fun-1.3B-InP_lora/epoch-4.safetensors\", alpha=1)\n\nvideo = VideoData(\"data/example_video_dataset/video1.mp4\", height=480, width=832)\n\n# First and last frame to video\nvideo = pipe(\n    prompt=\"from sunset to night, a small town, light, house, river\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    input_image=video[0], end_image=video[80],\n    seed=0, tiled=True\n)\nsave_video(video, \"video_Wan2.1-Fun-1.3B-InP.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-Control.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-14B-Control\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-14B-Control\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-14B-Control\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-14B-Control\", origin_file_pattern=\"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\"),\n    ],\n)\npipe.load_lora(pipe.dit, \"models/train/Wan2.1-Fun-14B-Control_lora/epoch-4.safetensors\", alpha=1)\n\nvideo = VideoData(\"data/example_video_dataset/video1_softedge.mp4\", height=480, width=832)\nvideo = [video[i] for i in range(81)]\n\n# Control video\nvideo = pipe(\n    prompt=\"from sunset to night, a small town, light, house, river\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    control_video=video,\n    seed=1, tiled=True\n)\nsave_video(video, \"video_Wan2.1-Fun-14B-Control.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-InP.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-14B-InP\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-14B-InP\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-14B-InP\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-14B-InP\", origin_file_pattern=\"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\"),\n    ],\n)\npipe.load_lora(pipe.dit, \"models/train/Wan2.1-Fun-14B-InP_lora/epoch-4.safetensors\", alpha=1)\n\nvideo = VideoData(\"data/example_video_dataset/video1.mp4\", height=480, width=832)\n\n# First and last frame to video\nvideo = pipe(\n    prompt=\"from sunset to night, a small town, light, house, river\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    input_image=video[0], end_image=video[80],\n    seed=0, tiled=True\n)\nsave_video(video, \"video_Wan2.1-Fun-14B-InP.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.core import load_state_dict\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera\", origin_file_pattern=\"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\"),\n    ],\n)\npipe.load_lora(pipe.dit, \"models/train/Wan2.1-Fun-V1.1-1.3B-Control-Camera_lora/epoch-4.safetensors\", alpha=1)\n\nvideo = VideoData(\"data/example_video_dataset/video1.mp4\", height=480, width=832)\n\n# Image to video with camera control\nvideo = pipe(\n    prompt=\"from sunset to night, a small town, light, house, river\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    input_image=video[0],\n    camera_control_direction=\"Left\", camera_control_speed=0.0,\n    seed=0, tiled=True\n)\nsave_video(video, \"video_Wan2.1-Fun-V1.1-1.3B-Control-Camera.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-1.3B-Control\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-1.3B-Control\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-1.3B-Control\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-1.3B-Control\", origin_file_pattern=\"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\"),\n    ],\n)\npipe.load_lora(pipe.dit, \"models/train/Wan2.1-Fun-V1.1-1.3B-Control_lora/epoch-4.safetensors\", alpha=1)\n\nvideo = VideoData(\"data/example_video_dataset/video1_softedge.mp4\", height=480, width=832)\nvideo = [video[i] for i in range(81)]\nreference_image = VideoData(\"data/example_video_dataset/video1.mp4\", height=480, width=832)[0]\n\n# Control video\nvideo = pipe(\n    prompt=\"from sunset to night, a small town, light, house, river\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    control_video=video, reference_image=reference_image,\n    seed=1, tiled=True\n)\nsave_video(video, \"video_Wan2.1-Fun-V1.1-1.3B-Control.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-InP.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-1.3B-InP\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-1.3B-InP\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-1.3B-InP\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-1.3B-InP\", origin_file_pattern=\"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\"),\n    ],\n)\npipe.load_lora(pipe.dit, \"models/train/Wan2.1-Fun-V1.1-1.3B-InP_lora/epoch-4.safetensors\", alpha=1)\n\nvideo = VideoData(\"data/example_video_dataset/video1.mp4\", height=480, width=832)\n\n# First and last frame to video\nvideo = pipe(\n    prompt=\"from sunset to night, a small town, light, house, river\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    input_image=video[0], end_image=video[80],\n    seed=0, tiled=True\n)\nsave_video(video, \"video_Wan2.1-Fun-V1.1-1.3B-InP.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control-Camera.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.core import load_state_dict\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-14B-Control-Camera\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-14B-Control-Camera\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-14B-Control-Camera\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-14B-Control-Camera\", origin_file_pattern=\"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\"),\n    ],\n)\npipe.load_lora(pipe.dit, \"models/train/Wan2.1-Fun-V1.1-14B-Control-Camera_lora/epoch-4.safetensors\", alpha=1)\n\nvideo = VideoData(\"data/example_video_dataset/video1.mp4\", height=480, width=832)\n\n# Image to video with camera control\nvideo = pipe(\n    prompt=\"from sunset to night, a small town, light, house, river\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    input_image=video[0],\n    camera_control_direction=\"Left\", camera_control_speed=0.0,\n    seed=0, tiled=True\n)\nsave_video(video, \"video_Wan2.1-Fun-V1.1-14B-Control-Camera.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-14B-Control\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-14B-Control\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-14B-Control\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-14B-Control\", origin_file_pattern=\"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\"),\n    ],\n)\npipe.load_lora(pipe.dit, \"models/train/Wan2.1-Fun-V1.1-14B-Control_lora/epoch-4.safetensors\", alpha=1)\n\nvideo = VideoData(\"data/example_video_dataset/video1_softedge.mp4\", height=480, width=832)\nvideo = [video[i] for i in range(81)]\nreference_image = VideoData(\"data/example_video_dataset/video1.mp4\", height=480, width=832)[0]\n\n# Control video\nvideo = pipe(\n    prompt=\"from sunset to night, a small town, light, house, river\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    control_video=video, reference_image=reference_image,\n    seed=1, tiled=True\n)\nsave_video(video, \"video_Wan2.1-Fun-V1.1-14B-Control.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-InP.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-14B-InP\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-14B-InP\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-14B-InP\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n        ModelConfig(model_id=\"PAI/Wan2.1-Fun-V1.1-14B-InP\", origin_file_pattern=\"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\"),\n    ],\n)\npipe.load_lora(pipe.dit, \"models/train/Wan2.1-Fun-V1.1-14B-InP_lora/epoch-4.safetensors\", alpha=1)\n\nvideo = VideoData(\"data/example_video_dataset/video1.mp4\", height=480, width=832)\n\n# First and last frame to video\nvideo = pipe(\n    prompt=\"from sunset to night, a small town, light, house, river\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    input_image=video[0], end_image=video[80],\n    seed=0, tiled=True\n)\nsave_video(video, \"video_Wan2.1-Fun-V1.1-14B-InP.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-I2V-14B-480P\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-I2V-14B-480P\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-I2V-14B-480P\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-I2V-14B-480P\", origin_file_pattern=\"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\"),\n    ],\n)\npipe.load_lora(pipe.dit, \"models/train/Wan2.1-I2V-14B-480P_lora/epoch-4.safetensors\", alpha=1)\n\ninput_image = VideoData(\"data/example_video_dataset/video1.mp4\", height=480, width=832)[0]\n\nvideo = pipe(\n    prompt=\"from sunset to night, a small town, light, house, river\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    input_image=input_image,\n    seed=1, tiled=True\n)\nsave_video(video, \"video_Wan2.1-I2V-14B-480P.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-720P.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-I2V-14B-720P\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-I2V-14B-720P\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-I2V-14B-720P\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-I2V-14B-720P\", origin_file_pattern=\"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\"),\n    ],\n)\npipe.load_lora(pipe.dit, \"models/train/Wan2.1-I2V-14B-720P_lora/epoch-4.safetensors\", alpha=1)\n\ninput_image = VideoData(\"data/example_video_dataset/video1.mp4\", height=720, width=1280)[0]\n\nvideo = pipe(\n    prompt=\"from sunset to night, a small town, light, house, river\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    input_image=input_image,\n    height=720, width=1280, num_frames=49,\n    seed=1, tiled=True\n)\nsave_video(video, \"video_Wan2.1-I2V-14B-720P.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-1.3B.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n    ],\n)\npipe.load_lora(pipe.dit, \"models/train/Wan2.1-T2V-1.3B_lora/epoch-4.safetensors\", alpha=1)\n\nvideo = pipe(\n    prompt=\"from sunset to night, a small town, light, house, river\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    seed=1, tiled=True\n)\nsave_video(video, \"video_Wan2.1-T2V-1.3B.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-14B\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-14B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-14B\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n    ],\n)\npipe.load_lora(pipe.dit, \"models/train/Wan2.1-T2V-14B_lora/epoch-4.safetensors\", alpha=1)\n\nvideo = pipe(\n    prompt=\"from sunset to night, a small town, light, house, river\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    seed=1, tiled=True\n)\nsave_video(video, \"video_Wan2.1-T2V-14B.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B-Preview.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"iic/VACE-Wan2.1-1.3B-Preview\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"iic/VACE-Wan2.1-1.3B-Preview\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"iic/VACE-Wan2.1-1.3B-Preview\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n    ],\n)\npipe.load_lora(pipe.vace, \"models/train/Wan2.1-VACE-1.3B-Preview_lora/epoch-4.safetensors\", alpha=1)\n\nvideo = VideoData(\"data/example_video_dataset/video1_softedge.mp4\", height=480, width=832)\nvideo = [video[i] for i in range(49)]\nreference_image = VideoData(\"data/example_video_dataset/video1.mp4\", height=480, width=832)[0]\n\nvideo = pipe(\n    prompt=\"from sunset to night, a small town, light, house, river\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    vace_video=video, vace_reference_image=reference_image, num_frames=49,\n    seed=1, tiled=True\n)\nsave_video(video, \"video_Wan2.1-VACE-1.3B-Preview.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-VACE-1.3B\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-VACE-1.3B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-VACE-1.3B\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n    ],\n)\npipe.load_lora(pipe.vace, \"models/train/Wan2.1-VACE-1.3B_lora/epoch-4.safetensors\", alpha=1)\n\nvideo = VideoData(\"data/example_video_dataset/video1_softedge.mp4\", height=480, width=832)\nvideo = [video[i] for i in range(49)]\nreference_image = VideoData(\"data/example_video_dataset/video1.mp4\", height=480, width=832)[0]\n\nvideo = pipe(\n    prompt=\"from sunset to night, a small town, light, house, river\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    vace_video=video, vace_reference_image=reference_image, num_frames=49,\n    seed=1, tiled=True\n)\nsave_video(video, \"video_Wan2.1-VACE-1.3B.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-VACE-14B\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-VACE-14B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-VACE-14B\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n    ],\n)\npipe.load_lora(pipe.vace, \"models/train/Wan2.1-VACE-14B_lora/epoch-4.safetensors\", alpha=1)\n\nvideo = VideoData(\"data/example_video_dataset/video1_softedge.mp4\", height=480, width=832)\nvideo = [video[i] for i in range(17)]\nreference_image = VideoData(\"data/example_video_dataset/video1.mp4\", height=480, width=832)[0]\n\nvideo = pipe(\n    prompt=\"from sunset to night, a small town, light, house, river\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    vace_video=video, vace_reference_image=reference_image, num_frames=17,\n    seed=1, tiled=True\n)\nsave_video(video, \"video_Wan2.1-VACE-14B.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.core import load_state_dict\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-Animate-14B\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-Animate-14B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-Animate-14B\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-Animate-14B\", origin_file_pattern=\"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\"),\n    ],\n)\npipe.load_lora(pipe.dit, \"models/train/Wan2.2-Animate-14B_lora/epoch-4.safetensors\", alpha=1)\n\ninput_image = VideoData(\"data/example_video_dataset/animate/animate_output.mp4\", height=480, width=832)[0]\nanimate_pose_video = VideoData(\"data/examples/wan/animate/animate_pose_video.mp4\", height=480, width=832).raw_data()[:81-4]\nanimate_face_video = VideoData(\"data/examples/wan/animate/animate_face_video.mp4\", height=512, width=512).raw_data()[:81-4]\nvideo = pipe(\n    prompt=\"视频中的人在做动作\",\n    seed=0, tiled=True,\n    input_image=input_image,\n    animate_pose_video=animate_pose_video,\n    animate_face_video=animate_face_video,\n    num_frames=81, height=480, width=832,\n    num_inference_steps=20, cfg_scale=1,\n)\nsave_video(video, \"video_Wan2.2-Animate-14B.mp4\", fps=15, quality=5)"
  },
  {
    "path": "examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.core import load_state_dict\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"PAI/Wan2.2-Fun-A14B-Control-Camera\", origin_file_pattern=\"high_noise_model/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"PAI/Wan2.2-Fun-A14B-Control-Camera\", origin_file_pattern=\"low_noise_model/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"PAI/Wan2.2-Fun-A14B-Control-Camera\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"PAI/Wan2.2-Fun-A14B-Control-Camera\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n    ],\n)\npipe.load_lora(pipe.dit, \"models/train/Wan2.2-Fun-A14B-Control-Camera_high_noise_lora/epoch-4.safetensors\", alpha=1)\npipe.load_lora(pipe.dit2, \"models/train/Wan2.2-Fun-A14B-Control-Camera_low_noise_lora/epoch-4.safetensors\", alpha=1)\n\nvideo = VideoData(\"data/example_video_dataset/video1.mp4\", height=480, width=832)\n\n# Image to video with camera control\nvideo = pipe(\n    prompt=\"from sunset to night, a small town, light, house, river\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    input_image=video[0],\n    camera_control_direction=\"Left\", camera_control_speed=0.0,\n    seed=0, tiled=True\n)\nsave_video(video, \"video_Wan2.2-Fun-A14B-Control-Camera.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"PAI/Wan2.2-Fun-A14B-Control\", origin_file_pattern=\"high_noise_model/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"PAI/Wan2.2-Fun-A14B-Control\", origin_file_pattern=\"low_noise_model/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"PAI/Wan2.2-Fun-A14B-Control\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"PAI/Wan2.2-Fun-A14B-Control\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n    ],\n)\npipe.load_lora(pipe.dit, \"models/train/Wan2.2-Fun-A14B-Control_high_noise_lora/epoch-4.safetensors\", alpha=1)\npipe.load_lora(pipe.dit2, \"models/train/Wan2.2-Fun-A14B-Control_low_noise_lora/epoch-4.safetensors\", alpha=1)\n\nvideo = VideoData(\"data/example_video_dataset/video1_softedge.mp4\", height=480, width=832)\nvideo = [video[i] for i in range(81)]\nreference_image = VideoData(\"data/example_video_dataset/video1.mp4\", height=480, width=832)[0]\n\n# Control video\nvideo = pipe(\n    prompt=\"from sunset to night, a small town, light, house, river\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    control_video=video, reference_image=reference_image,\n    seed=1, tiled=True\n)\nsave_video(video, \"video_Wan2.2-Fun-A14B-Control.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"PAI/Wan2.2-Fun-A14B-InP\", origin_file_pattern=\"high_noise_model/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"PAI/Wan2.2-Fun-A14B-InP\", origin_file_pattern=\"low_noise_model/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"PAI/Wan2.2-Fun-A14B-InP\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"PAI/Wan2.2-Fun-A14B-InP\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n    ],\n)\npipe.load_lora(pipe.dit, \"models/train/Wan2.2-Fun-A14B-InP_high_noise_lora/epoch-4.safetensors\", alpha=1)\npipe.load_lora(pipe.dit2, \"models/train/Wan2.2-Fun-A14B-InP_low_noise_lora/epoch-4.safetensors\", alpha=1)\n\nvideo = VideoData(\"data/example_video_dataset/video1.mp4\", height=480, width=832)\n\n# First and last frame to video\nvideo = pipe(\n    prompt=\"from sunset to night, a small town, light, house, river\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    input_image=video[0], end_image=video[80],\n    seed=0, tiled=True\n)\nsave_video(video, \"video_Wan2.2-Fun-A14B-InP.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-I2V-A14B\", origin_file_pattern=\"high_noise_model/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-I2V-A14B\", origin_file_pattern=\"low_noise_model/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-I2V-A14B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-I2V-A14B\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n    ],\n)\npipe.load_lora(pipe.dit, \"models/train/Wan2.2-I2V-A14B_high_noise_lora/epoch-4.safetensors\", alpha=1)\npipe.load_lora(pipe.dit2, \"models/train/Wan2.2-I2V-A14B_low_noise_lora/epoch-4.safetensors\", alpha=1)\n\ninput_image = VideoData(\"data/example_video_dataset/video1.mp4\", height=480, width=832)[0]\n\nvideo = pipe(\n    prompt=\"from sunset to night, a small town, light, house, river\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    input_image=input_image,\n    num_frames=49,\n    seed=1, tiled=True,\n)\nsave_video(video, \"video_Wan2.2-I2V-A14B.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_training/validate_lora/Wan2.2-S2V-14B.py",
    "content": "import torch\nfrom PIL import Image\nimport librosa\nfrom diffsynth.utils.data import VideoData, save_video_with_audio\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-S2V-14B\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-S2V-14B\", origin_file_pattern=\"wav2vec2-large-xlsr-53-english/model.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-S2V-14B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-S2V-14B\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n    ],\n    audio_processor_config=ModelConfig(model_id=\"Wan-AI/Wan2.2-S2V-14B\", origin_file_pattern=\"wav2vec2-large-xlsr-53-english/\"),\n)\n\npipe.load_lora(pipe.dit, \"models/train/Wan2.2-S2V-14B_lora/epoch-4.safetensors\", alpha=1)\n\n\nnum_frames = 81 # 4n+1\nheight = 448\nwidth = 832\n\nprompt = \"a person is singing\"\nnegative_prompt = \"画面模糊，最差质量，画面模糊，细节模糊不清，情绪激动剧烈，手快速抖动，字幕，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\"\ninput_image = Image.open(\"data/example_video_dataset/wans2v/pose.png\").convert(\"RGB\").resize((width, height))\n# s2v audio input, recommend 16kHz sampling rate\naudio_path = 'data/example_video_dataset/wans2v/sing.MP3'\ninput_audio, sample_rate = librosa.load(audio_path, sr=16000)\n# Pose video input\npose_video_path = 'data/example_video_dataset/wans2v/pose.mp4'\npose_video = VideoData(pose_video_path, height=height, width=width)\n\n# Speech-to-video with pose\nvideo = pipe(\n    prompt=prompt,\n    input_image=input_image,\n    negative_prompt=negative_prompt,\n    seed=0,\n    num_frames=num_frames,\n    height=height,\n    width=width,\n    audio_sample_rate=sample_rate,\n    input_audio=input_audio,\n    s2v_pose_video=pose_video,\n    num_inference_steps=40,\n)\nsave_video_with_audio(video[1:], \"video_Wan2.2-S2V-14B.mp4\", audio_path, fps=16, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-T2V-A14B\", origin_file_pattern=\"high_noise_model/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-T2V-A14B\", origin_file_pattern=\"low_noise_model/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-T2V-A14B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-T2V-A14B\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n    ],\n)\npipe.load_lora(pipe.dit, \"models/train/Wan2.2-T2V-A14B_high_noise_lora/epoch-4.safetensors\", alpha=1)\npipe.load_lora(pipe.dit2, \"models/train/Wan2.2-T2V-A14B_low_noise_lora/epoch-4.safetensors\", alpha=1)\n\nvideo = pipe(\n    prompt=\"from sunset to night, a small town, light, house, river\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    num_frames=49,\n    seed=1, tiled=True\n)\nsave_video(video, \"video_Wan2.2-T2V-A14B.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.core import load_state_dict\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-TI2V-5B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-TI2V-5B\", origin_file_pattern=\"diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.2-TI2V-5B\", origin_file_pattern=\"Wan2.2_VAE.pth\"),\n    ],\n)\npipe.load_lora(pipe.dit, \"models/train/Wan2.2-TI2V-5B_lora/epoch-4.safetensors\", alpha=1)\n\ninput_image = VideoData(\"data/example_video_dataset/video1.mp4\", height=480, width=832)[0]\n\nvideo = pipe(\n    prompt=\"from sunset to night, a small town, light, house, river\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    input_image=input_image,\n    num_frames=49,\n    seed=1, tiled=True,\n)\nsave_video(video, \"video_Wan2.2-TI2V-5B.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_training/validate_lora/Wan2.2-VACE-Fun-A14B.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"PAI/Wan2.2-VACE-Fun-A14B\", origin_file_pattern=\"high_noise_model/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"PAI/Wan2.2-VACE-Fun-A14B\", origin_file_pattern=\"low_noise_model/diffusion_pytorch_model*.safetensors\"),\n        ModelConfig(model_id=\"PAI/Wan2.2-VACE-Fun-A14B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"PAI/Wan2.2-VACE-Fun-A14B\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n    ],\n)\npipe.load_lora(pipe.vace, \"models/train/Wan2.2-VACE-Fun-A14B_high_noise_lora/epoch-4.safetensors\", alpha=1)\npipe.load_lora(pipe.vace2, \"models/train/Wan2.2-VACE-Fun-A14B_low_noise_lora/epoch-4.safetensors\", alpha=1)\n\nvideo = VideoData(\"data/example_video_dataset/video1_softedge.mp4\", height=480, width=832)\nvideo = [video[i] for i in range(17)]\nreference_image = VideoData(\"data/example_video_dataset/video1.mp4\", height=480, width=832)[0]\n\nvideo = pipe(\n    prompt=\"from sunset to night, a small town, light, house, river\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    vace_video=video, vace_reference_image=reference_image, num_frames=17,\n    seed=1, tiled=True\n)\nsave_video(video, \"video_Wan2.2-VACE-Fun-A14B.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_training/validate_lora/WanToDance-14B-global.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/WanToDance-14B\", origin_file_pattern=\"global_model.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/WanToDance-14B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"Wan-AI/WanToDance-14B\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n        ModelConfig(model_id=\"Wan-AI/WanToDance-14B\", origin_file_pattern=\"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n)\npipe.load_lora(pipe.dit, \"models/train/WanToDance-14B-global_lora/epoch-4.safetensors\", alpha=1)\ndataset_snapshot_download(\n    \"DiffSynth-Studio/diffsynth_example_dataset\",\n    local_dir=\"data/diffsynth_example_dataset\",\n    allow_file_pattern=\"wanvideo/WanToDance-14B-global/*\"\n)\n# This is a specialized model with the following constraints on its input parameters:\n# *   The model outputs a sequence of keyframes rather than a video; therefore, `framewise_decoding=True` must be set.\n# *   When the number of keyframes is $n$, `num_frames` = 4 * (n - 1) + 1.\n# *   Reducing `height`, `width`, `num_frames`, or `num_inference_steps` may lead to severe artifacts or generation failure.\n# *   The audio file specified by `wantodance_music_path` must match the video duration, calculated as (`num_frames` / 7.5) seconds.\n# *   The width and height of `wantodance_reference_image` must be multiples of 16.\n# *   `wantodance_fps` is configurable, but since the model appears to have been trained exclusively at 7.5 FPS, setting it to other 
values is not recommended.\n# *   The first frame of `wantodance_keyframes` is the `wantodance_reference_image`, while all subsequent frames are solid black.\n# *   `wantodance_keyframes_mask` indicates the positions of valid frames within `wantodance_keyframes`.\nwantodance_keyframes = VideoData(\"data/diffsynth_example_dataset/wanvideo/WanToDance-14B-global/keyframes.mp4\")\nwantodance_keyframes = [wantodance_keyframes[i] for i in range(149)]\nvideo = pipe(\n    prompt=\"一个人正在跳舞，舞蹈种类是韩舞。帧率是7.5000\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    seed=0, tiled=False,\n    height=1280, width=720, num_frames=149,\n    num_inference_steps=48,\n    wantodance_music_path=\"data/diffsynth_example_dataset/wanvideo/WanToDance-14B-global/music.WAV\",\n    wantodance_reference_image=Image.open(\"data/diffsynth_example_dataset/wanvideo/WanToDance-14B-global/refimage.jpg\"),\n    wantodance_fps=7.5,\n    wantodance_keyframes=wantodance_keyframes,\n    wantodance_keyframes_mask=[1] + [0] * 148,\n    framewise_decoding=True,\n)\nsave_video(video, \"video_WanToDance-14B-global.mp4\", fps=7.5, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_training/validate_lora/WanToDance-14B-local.py",
    "content": "import torch, os\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\nfrom modelscope import dataset_snapshot_download\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Wan-AI/WanToDance-14B\", origin_file_pattern=\"local_model.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/WanToDance-14B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"Wan-AI/WanToDance-14B\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n        ModelConfig(model_id=\"Wan-AI/WanToDance-14B\", origin_file_pattern=\"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-1.3B\", origin_file_pattern=\"google/umt5-xxl/\"),\n)\npipe.load_lora(pipe.dit, \"models/train/WanToDance-14B-local_lora/epoch-4.safetensors\", alpha=1)\ndataset_snapshot_download(\n    \"DiffSynth-Studio/diffsynth_example_dataset\",\n    local_dir=\"data/diffsynth_example_dataset\",\n    allow_file_pattern=\"wanvideo/WanToDance-14B-local/*\"\n)\n# This is a specialized model with the following constraints on its input parameters:\n# *   The model renders and outputs video based on a sequence of keyframes; therefore, `wantodance_keyframes` must be provided correctly.\n# *   If you need to generate a long video, please generate it in segments, and ensure that `wantodance_music_path`, `wantodance_keyframes`, and `wantodance_keyframes_mask` are properly split accordingly.\n# *   The audio file specified by `wantodance_music_path` must match the video duration, calculated as (`num_frames` / 30) seconds.\n# *   The width and height of `wantodance_reference_image` must be multiples of 16.\n# *   `wantodance_fps` is configurable, but since the model appears to have been trained exclusively at 30 FPS, 
setting it to other values is not recommended.\n# *   In `wantodance_keyframes`, frames that are not keyframes should be solid black.\n# *   `wantodance_keyframes_mask` indicates the positions of valid frames within `wantodance_keyframes`.\nwantodance_keyframes = VideoData(\"data/diffsynth_example_dataset/wanvideo/WanToDance-14B-local/keyframes.mp4\")\nwantodance_keyframes = [wantodance_keyframes[i] for i in range(149)]\nvideo = pipe(\n    prompt=\"一个人正在跳舞，舞蹈种类是古典舞,图像清晰程度高,人物动作平均幅度中等,人物动作最大幅度中等。, 帧率是30fps。\",\n    negative_prompt=\"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\",\n    seed=0, tiled=True,\n    height=1280, width=720, num_frames=149,\n    num_inference_steps=24,\n    wantodance_music_path=\"data/diffsynth_example_dataset/wanvideo/WanToDance-14B-local/music.wav\",\n    wantodance_reference_image=Image.open(\"data/diffsynth_example_dataset/wanvideo/WanToDance-14B-local/refimage.jpg\"),\n    wantodance_fps=30,\n    wantodance_keyframes=wantodance_keyframes,\n    wantodance_keyframes_mask=[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n                               1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n                               1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n                               1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n                               1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n                               1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n                               1],\n)\nsave_video(video, \"video_WanToDance-14B-local.mp4\", fps=30, quality=5)\n"
  },
  {
    "path": "examples/wanvideo/model_training/validate_lora/krea-realtime-video.py",
    "content": "import torch\nfrom PIL import Image\nfrom diffsynth.utils.data import save_video, VideoData\nfrom diffsynth.core import load_state_dict\nfrom diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig\n\n\npipe = WanVideoPipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"krea/krea-realtime-video\", origin_file_pattern=\"krea-realtime-video-14b.safetensors\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-14B\", origin_file_pattern=\"models_t5_umt5-xxl-enc-bf16.pth\"),\n        ModelConfig(model_id=\"Wan-AI/Wan2.1-T2V-14B\", origin_file_pattern=\"Wan2.1_VAE.pth\"),\n    ],\n)\n\npipe.load_lora(pipe.dit, \"models/train/krea-realtime-video_lora/epoch-4.safetensors\", alpha=1)\n\n# Text-to-video\nvideo = pipe(\n    prompt=\"a cat sitting on a boat\",\n    num_inference_steps=6, num_frames=81,\n    seed=0, tiled=True,\n    cfg_scale=1,\n    sigma_shift=20,\n)\nsave_video(video, \"video_krea-realtime-video.mp4\", fps=15, quality=5)\n"
  },
  {
    "path": "examples/z_image/README.md",
    "content": "English Document: https://diffsynth-studio-doc.readthedocs.io/en/latest/Model_Details/Z-Image.html\n\n中文文档：https://diffsynth-studio-doc.readthedocs.io/zh-cn/latest/Model_Details/Z-Image.html\n"
  },
  {
    "path": "examples/z_image/model_inference/Z-Image-Omni-Base-i2L.py",
    "content": "from diffsynth.pipelines.z_image import (\n    ZImagePipeline, ModelConfig,\n    ZImageUnit_Image2LoRAEncode, ZImageUnit_Image2LoRADecode\n)\nfrom modelscope import snapshot_download\nfrom safetensors.torch import save_file\nimport torch\nfrom PIL import Image\n\n# Use `vram_config` to enable LoRA hot-loading\nvram_config = {\n    \"offload_dtype\": torch.bfloat16,\n    \"offload_device\": \"cuda\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cuda\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\n\n# Load models\npipe = ZImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Omni-Base\", origin_file_pattern=\"transformer/*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Omni-Base\", origin_file_pattern=\"siglip/model.safetensors\"),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"text_encoder/*.safetensors\"),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n        ModelConfig(model_id=\"DiffSynth-Studio/General-Image-Encoders\", origin_file_pattern=\"SigLIP2-G384/model.safetensors\"),\n        ModelConfig(model_id=\"DiffSynth-Studio/General-Image-Encoders\", origin_file_pattern=\"DINOv3-7B/model.safetensors\"),\n        ModelConfig(model_id=\"DiffSynth-Studio/Z-Image-Omni-Base-i2L\", origin_file_pattern=\"model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"tokenizer/\"),\n)\n\n# Load images\nsnapshot_download(\n    model_id=\"DiffSynth-Studio/Z-Image-Omni-Base-i2L\",\n    allow_file_pattern=\"assets/style/*\",\n    local_dir=\"data/style_input\"\n)\nimages = 
[Image.open(f\"data/style_input/assets/style/1/{i}.jpg\") for i in range(6)]\n\n# Image to LoRA\nwith torch.no_grad():\n    embs = ZImageUnit_Image2LoRAEncode().process(pipe, image2lora_images=images)\n    lora = ZImageUnit_Image2LoRADecode().process(pipe, **embs)[\"lora\"]\nsave_file(lora, \"lora.safetensors\")\n\n# Generate images\nprompt = \"a cat\"\nnegative_prompt = \"泛黄，发绿，模糊，低分辨率，低质量图像，扭曲的肢体，诡异的外观，丑陋，AI感，噪点，网格感，JPEG压缩条纹，异常的肢体，水印，乱码，意义不明的字符\"\nimage = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    seed=0, cfg_scale=7, num_inference_steps=50,\n    positive_only_lora=lora,\n    sigma_shift=8\n)\nimage.save(\"image.jpg\")\n"
  },
  {
    "path": "examples/z_image/model_inference/Z-Image-Omni-Base.py",
    "content": "from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig\nfrom PIL import Image\nimport torch\n\n\npipe = ZImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Omni-Base\", origin_file_pattern=\"transformer/*.safetensors\"),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Omni-Base\", origin_file_pattern=\"siglip/model.safetensors\"),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"text_encoder/*.safetensors\"),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"tokenizer/\"),\n)\nprompt = \"Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights.\"\nimage = pipe(prompt=prompt, seed=0, num_inference_steps=40, cfg_scale=4)\nimage.save(\"image_Z-Image-Omni-Base.jpg\")\n\nimage = Image.open(\"image_Z-Image-Omni-Base.jpg\")\nprompt = \"Change the women's clothes to white cheongsam, keep other content unchanged\"\nimage = pipe(prompt=prompt, edit_image=image, seed=42, rand_device=\"cuda\", num_inference_steps=40, cfg_scale=4)\nimage.save(\"image_edit_Z-Image-Omni-Base.jpg\")\n"
  },
  {
    "path": "examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py",
    "content": "from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig, ControlNetInput\nfrom modelscope import dataset_snapshot_download\nfrom PIL import Image\nimport torch\n\n\npipe = ZImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1\", origin_file_pattern=\"Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.safetensors\"),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"transformer/*.safetensors\"),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"text_encoder/*.safetensors\"),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"tokenizer/\"),\n)\n\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/examples_in_diffsynth\",\n    local_dir=\"./\",\n    allow_file_pattern=\"data/examples/upscale/low_res.png\"\n)\ncontrolnet_image = Image.open(\"data/examples/upscale/low_res.png\").resize((1024, 1024))\nprompt = \"这是一张充满都市气息的户外人物肖像照片。画面中是一位年轻男性，他展现出时尚而自信的形象。人物拥有精心打理的短发发型，两侧修剪得较短，顶部保留一定长度，呈现出流行的Undercut造型。他佩戴着一副时尚的浅色墨镜或透明镜框眼镜，为整体造型增添了潮流感。脸上洋溢着温和友善的笑容，神情放松自然，给人以阳光开朗的印象。他身穿一件经典的牛仔外套，这件单品永不过时，展现出休闲又有型的穿衣风格。牛仔外套的蓝色调与整体氛围十分协调，领口处隐约可见内搭的衣物。照片的背景是典型的城市街景，可以看到模糊的建筑物、街道和行人，营造出繁华都市的氛围。背景经过了恰当的虚化处理，使人物主体更加突出。光线明亮而柔和，可能是白天的自然光，为照片带来清新通透的视觉效果。整张照片构图专业，景深控制得当，完美捕捉了一个现代都市年轻人充满活力和自信的瞬间，展现出积极向上的生活态度。\"\nimage = pipe(prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(image=controlnet_image, scale=0.7)])\nimage.save(\"image_tile.jpg\")\n"
  },
  {
    "path": "examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py",
    "content": "from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig, ControlNetInput\nfrom modelscope import dataset_snapshot_download\nfrom PIL import Image\nimport torch\n\n\npipe = ZImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1\", origin_file_pattern=\"Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.safetensors\"),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"transformer/*.safetensors\"),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"text_encoder/*.safetensors\"),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"tokenizer/\"),\n)\n\n# Control\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/example_image_dataset\",\n    local_dir=\"./data/example_image_dataset\",\n    allow_file_pattern=\"depth/image_1.jpg\"\n)\ncontrolnet_image = Image.open(\"data/example_image_dataset/depth/image_1.jpg\").resize((1024, 1024))\nprompt = \"精致肖像，水下少女，蓝裙飘逸，发丝轻扬，光影透澈，气泡环绕，面容恬静，细节精致，梦幻唯美。\"\nimage = pipe(prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(image=controlnet_image, scale=0.7)])\nimage.save(\"image_control.jpg\")\n\n# Inpaint\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/example_image_dataset\",\n    local_dir=\"./data/example_image_dataset\",\n    allow_file_pattern=\"inpaint/*.jpg\"\n)\ninpaint_image = Image.open(\"./data/example_image_dataset/inpaint/image_1.jpg\").convert(\"RGB\").resize((1024, 1024))\ninpaint_mask = Image.open(\"./data/example_image_dataset/inpaint/mask.jpg\").convert(\"RGB\").resize((1024, 1024))\nprompt = \"一只戴着墨镜的猫\"\nimage = pipe(prompt=prompt, seed=0, height=1024, width=1024, 
controlnet_inputs=[ControlNetInput(inpaint_image=inpaint_image, inpaint_mask=inpaint_mask, scale=0.7)])\nimage.save(\"image_inpaint.jpg\")\n"
  },
  {
    "path": "examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py",
    "content": "from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig, ControlNetInput\nfrom modelscope import dataset_snapshot_download\nfrom PIL import Image\nimport torch\n\n\npipe = ZImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1\", origin_file_pattern=\"Z-Image-Turbo-Fun-Controlnet-Union-2.1.safetensors\"),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"transformer/*.safetensors\"),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"text_encoder/*.safetensors\"),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"tokenizer/\"),\n)\n\n# Control\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/example_image_dataset\",\n    local_dir=\"./data/example_image_dataset\",\n    allow_file_pattern=\"depth/image_1.jpg\"\n)\ncontrolnet_image = Image.open(\"data/example_image_dataset/depth/image_1.jpg\").resize((1024, 1024))\nprompt = \"精致肖像，水下少女，蓝裙飘逸，发丝轻扬，光影透澈，气泡环绕，面容恬静，细节精致，梦幻唯美。\"\nimage = pipe(\n    prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(image=controlnet_image, scale=0.7)],\n    num_inference_steps=30,\n)\nimage.save(\"image_control.jpg\")\n\n# Inpaint\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/example_image_dataset\",\n    local_dir=\"./data/example_image_dataset\",\n    allow_file_pattern=\"inpaint/*.jpg\"\n)\ninpaint_image = Image.open(\"./data/example_image_dataset/inpaint/image_1.jpg\").convert(\"RGB\").resize((1024, 1024))\ninpaint_mask = Image.open(\"./data/example_image_dataset/inpaint/mask.jpg\").convert(\"RGB\").resize((1024, 1024))\nprompt = \"一只戴着墨镜的猫\"\nimage = pipe(\n    prompt=prompt, seed=0, 
height=1024, width=1024, controlnet_inputs=[ControlNetInput(inpaint_image=inpaint_image, inpaint_mask=inpaint_mask, scale=0.7)],\n    num_inference_steps=30,\n)\nimage.save(\"image_inpaint.jpg\")\n"
  },
  {
    "path": "examples/z_image/model_inference/Z-Image-Turbo.py",
    "content": "from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig\nimport torch\n\n\npipe = ZImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"transformer/*.safetensors\"),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"text_encoder/*.safetensors\"),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"tokenizer/\"),\n)\nprompt = \"Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights.\"\nimage = pipe(prompt=prompt, seed=42, rand_device=\"cuda\")\nimage.save(\"image_Z-Image-Turbo.jpg\")\n"
  },
  {
    "path": "examples/z_image/model_inference/Z-Image-i2L.py",
    "content": "from diffsynth.pipelines.z_image import (\n    ZImagePipeline, ModelConfig,\n    ZImageUnit_Image2LoRAEncode, ZImageUnit_Image2LoRADecode\n)\nfrom modelscope import snapshot_download\nfrom safetensors.torch import save_file\nimport torch\nfrom PIL import Image\n\n# Use `vram_config` to enable LoRA hot-loading\nvram_config = {\n    \"offload_dtype\": torch.bfloat16,\n    \"offload_device\": \"cuda\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cuda\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\n\n# Load models\npipe = ZImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image\", origin_file_pattern=\"transformer/*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"text_encoder/*.safetensors\"),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n        ModelConfig(model_id=\"DiffSynth-Studio/General-Image-Encoders\", origin_file_pattern=\"SigLIP2-G384/model.safetensors\"),\n        ModelConfig(model_id=\"DiffSynth-Studio/General-Image-Encoders\", origin_file_pattern=\"DINOv3-7B/model.safetensors\"),\n        ModelConfig(model_id=\"DiffSynth-Studio/Z-Image-i2L\", origin_file_pattern=\"model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"tokenizer/\"),\n)\n\n# Load images\nsnapshot_download(\n    model_id=\"DiffSynth-Studio/Z-Image-i2L\",\n    allow_file_pattern=\"assets/style/*\",\n    local_dir=\"data/Z-Image-i2L_style_input\"\n)\nimages = [Image.open(f\"data/Z-Image-i2L_style_input/assets/style/1/{i}.jpg\") for i in range(4)]\n\n# Image to LoRA\nwith torch.no_grad():\n    embs = 
ZImageUnit_Image2LoRAEncode().process(pipe, image2lora_images=images)\n    lora = ZImageUnit_Image2LoRADecode().process(pipe, **embs)[\"lora\"]\nsave_file(lora, \"lora.safetensors\")\n\n# Generate images\nprompt = \"a cat\"\nnegative_prompt = \"泛黄，发绿，模糊，低分辨率，低质量图像，扭曲的肢体，诡异的外观，丑陋，AI感，噪点，网格感，JPEG压缩条纹，异常的肢体，水印，乱码，意义不明的字符\"\nimage = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    seed=0, cfg_scale=4, num_inference_steps=50,\n    positive_only_lora=lora,\n    sigma_shift=8\n)\nimage.save(\"image.jpg\")\n"
  },
  {
    "path": "examples/z_image/model_inference/Z-Image.py",
    "content": "from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig\nimport torch\n\n\npipe = ZImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image\", origin_file_pattern=\"transformer/*.safetensors\"),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"text_encoder/*.safetensors\"),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"tokenizer/\"),\n)\nprompt = \"Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights.\"\nimage = pipe(prompt=prompt, seed=42, rand_device=\"cuda\", num_inference_steps=50, cfg_scale=4)\nimage.save(\"image_Z-Image.jpg\")\n"
  },
  {
    "path": "examples/z_image/model_inference_low_vram/Z-Image-Omni-Base-i2L.py",
    "content": "from diffsynth.pipelines.z_image import (\n    ZImagePipeline, ModelConfig,\n    ZImageUnit_Image2LoRAEncode, ZImageUnit_Image2LoRADecode\n)\nfrom modelscope import snapshot_download\nfrom safetensors.torch import save_file\nimport torch\nfrom PIL import Image\n\n# Use `vram_config` to enable LoRA hot-loading\nvram_config = {\n    \"offload_dtype\": torch.bfloat16,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\n\n# Load models\npipe = ZImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Omni-Base\", origin_file_pattern=\"transformer/*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Omni-Base\", origin_file_pattern=\"siglip/model.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"text_encoder/*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/General-Image-Encoders\", origin_file_pattern=\"SigLIP2-G384/model.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/General-Image-Encoders\", origin_file_pattern=\"DINOv3-7B/model.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/Z-Image-Omni-Base-i2L\", origin_file_pattern=\"model.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"tokenizer/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\n\n# Load images\nsnapshot_download(\n    
model_id=\"DiffSynth-Studio/Z-Image-Omni-Base-i2L\",\n    allow_file_pattern=\"assets/style/*\",\n    local_dir=\"data/style_input\"\n)\nimages = [Image.open(f\"data/style_input/assets/style/1/{i}.jpg\") for i in range(6)]\n\n# Image to LoRA\nwith torch.no_grad():\n    embs = ZImageUnit_Image2LoRAEncode().process(pipe, image2lora_images=images)\n    lora = ZImageUnit_Image2LoRADecode().process(pipe, **embs)[\"lora\"]\nsave_file(lora, \"lora.safetensors\")\n\n# Generate images\nprompt = \"a cat\"\nnegative_prompt = \"泛黄，发绿，模糊，低分辨率，低质量图像，扭曲的肢体，诡异的外观，丑陋，AI感，噪点，网格感，JPEG压缩条纹，异常的肢体，水印，乱码，意义不明的字符\"\nimage = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    seed=0, cfg_scale=7, num_inference_steps=50,\n    positive_only_lora=lora,\n    sigma_shift=8\n)\nimage.save(\"image.jpg\")\n"
  },
  {
    "path": "examples/z_image/model_inference_low_vram/Z-Image-Omni-Base.py",
    "content": "from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig\nfrom PIL import Image\nimport torch\n\nvram_config = {\n    \"offload_dtype\": torch.bfloat16,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = ZImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Omni-Base\", origin_file_pattern=\"transformer/*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Omni-Base\", origin_file_pattern=\"siglip/model.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"text_encoder/*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"tokenizer/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\nprompt = \"Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. 
Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights.\"\nimage = pipe(prompt=prompt, seed=0, num_inference_steps=40, cfg_scale=4)\nimage.save(\"image_Z-Image-Omni-Base.jpg\")\n\nimage = Image.open(\"image_Z-Image-Omni-Base.jpg\")\nprompt = \"Change the women's clothes to white cheongsam, keep other content unchanged\"\nimage = pipe(prompt=prompt, edit_image=image, seed=42, rand_device=\"cuda\", num_inference_steps=40, cfg_scale=4)\nimage.save(\"image_edit_Z-Image-Omni-Base.jpg\")\n"
  },
  {
    "path": "examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py",
    "content": "from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig, ControlNetInput\nfrom modelscope import dataset_snapshot_download\nfrom PIL import Image\nimport torch\n\n\nvram_config = {\n    \"offload_dtype\": torch.bfloat16,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = ZImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1\", origin_file_pattern=\"Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"transformer/*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"text_encoder/*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"tokenizer/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\n\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/examples_in_diffsynth\",\n    local_dir=\"./\",\n    allow_file_pattern=\"data/examples/upscale/low_res.png\"\n)\ncontrolnet_image = Image.open(\"data/examples/upscale/low_res.png\").resize((1024, 1024))\nprompt = 
\"这是一张充满都市气息的户外人物肖像照片。画面中是一位年轻男性，他展现出时尚而自信的形象。人物拥有精心打理的短发发型，两侧修剪得较短，顶部保留一定长度，呈现出流行的Undercut造型。他佩戴着一副时尚的浅色墨镜或透明镜框眼镜，为整体造型增添了潮流感。脸上洋溢着温和友善的笑容，神情放松自然，给人以阳光开朗的印象。他身穿一件经典的牛仔外套，这件单品永不过时，展现出休闲又有型的穿衣风格。牛仔外套的蓝色调与整体氛围十分协调，领口处隐约可见内搭的衣物。照片的背景是典型的城市街景，可以看到模糊的建筑物、街道和行人，营造出繁华都市的氛围。背景经过了恰当的虚化处理，使人物主体更加突出。光线明亮而柔和，可能是白天的自然光，为照片带来清新通透的视觉效果。整张照片构图专业，景深控制得当，完美捕捉了一个现代都市年轻人充满活力和自信的瞬间，展现出积极向上的生活态度。\"\nimage = pipe(prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(image=controlnet_image, scale=0.7)])\nimage.save(\"image_tile.jpg\")\n"
  },
  {
    "path": "examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py",
    "content": "from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig, ControlNetInput\nfrom modelscope import dataset_snapshot_download\nfrom PIL import Image\nimport torch\n\n\nvram_config = {\n    \"offload_dtype\": torch.bfloat16,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = ZImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1\", origin_file_pattern=\"Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"transformer/*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"text_encoder/*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"tokenizer/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\n\n# Control\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/example_image_dataset\",\n    local_dir=\"./data/example_image_dataset\",\n    allow_file_pattern=\"depth/image_1.jpg\"\n)\ncontrolnet_image = Image.open(\"data/example_image_dataset/depth/image_1.jpg\").resize((1024, 1024))\nprompt = \"精致肖像，水下少女，蓝裙飘逸，发丝轻扬，光影透澈，气泡环绕，面容恬静，细节精致，梦幻唯美。\"\nimage = pipe(prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(image=controlnet_image, scale=0.7)])\nimage.save(\"image_control.jpg\")\n\n# Inpaint\ndataset_snapshot_download(\n    
dataset_id=\"DiffSynth-Studio/example_image_dataset\",\n    local_dir=\"./data/example_image_dataset\",\n    allow_file_pattern=\"inpaint/*.jpg\"\n)\ninpaint_image = Image.open(\"./data/example_image_dataset/inpaint/image_1.jpg\").convert(\"RGB\").resize((1024, 1024))\ninpaint_mask = Image.open(\"./data/example_image_dataset/inpaint/mask.jpg\").convert(\"RGB\").resize((1024, 1024))\nprompt = \"一只戴着墨镜的猫\"\nimage = pipe(prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(inpaint_image=inpaint_image, inpaint_mask=inpaint_mask, scale=0.7)])\nimage.save(\"image_inpaint.jpg\")\n"
  },
  {
    "path": "examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py",
    "content": "from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig, ControlNetInput\nfrom modelscope import dataset_snapshot_download\nfrom PIL import Image\nimport torch\n\n\nvram_config = {\n    \"offload_dtype\": torch.bfloat16,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = ZImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1\", origin_file_pattern=\"Z-Image-Turbo-Fun-Controlnet-Union-2.1.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"transformer/*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"text_encoder/*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"tokenizer/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\n\n# Control\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/example_image_dataset\",\n    local_dir=\"./data/example_image_dataset\",\n    allow_file_pattern=\"depth/image_1.jpg\"\n)\ncontrolnet_image = Image.open(\"data/example_image_dataset/depth/image_1.jpg\").resize((1024, 1024))\nprompt = \"精致肖像，水下少女，蓝裙飘逸，发丝轻扬，光影透澈，气泡环绕，面容恬静，细节精致，梦幻唯美。\"\nimage = pipe(\n    prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(image=controlnet_image, scale=0.7)],\n    num_inference_steps=30,\n)\nimage.save(\"image_control.jpg\")\n\n# 
Inpaint\ndataset_snapshot_download(\n    dataset_id=\"DiffSynth-Studio/example_image_dataset\",\n    local_dir=\"./data/example_image_dataset\",\n    allow_file_pattern=\"inpaint/*.jpg\"\n)\ninpaint_image = Image.open(\"./data/example_image_dataset/inpaint/image_1.jpg\").convert(\"RGB\").resize((1024, 1024))\ninpaint_mask = Image.open(\"./data/example_image_dataset/inpaint/mask.jpg\").convert(\"RGB\").resize((1024, 1024))\nprompt = \"一只戴着墨镜的猫\"\nimage = pipe(\n    prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(inpaint_image=inpaint_image, inpaint_mask=inpaint_mask, scale=0.7)],\n    num_inference_steps=30,\n)\nimage.save(\"image_inpaint.jpg\")\n"
  },
  {
    "path": "examples/z_image/model_inference_low_vram/Z-Image-Turbo.py",
    "content": "from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig\nimport torch\n\nvram_config = {\n    \"offload_dtype\": torch.bfloat16,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = ZImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"transformer/*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"text_encoder/*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"tokenizer/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\nprompt = \"Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights.\"\nimage = pipe(prompt=prompt, seed=42, rand_device=\"cuda\")\nimage.save(\"image.jpg\")"
  },
  {
    "path": "examples/z_image/model_inference_low_vram/Z-Image-i2L.py",
    "content": "from diffsynth.pipelines.z_image import (\n    ZImagePipeline, ModelConfig,\n    ZImageUnit_Image2LoRAEncode, ZImageUnit_Image2LoRADecode\n)\nfrom modelscope import snapshot_download\nfrom safetensors.torch import save_file\nimport torch\nfrom PIL import Image\n\n# Use `vram_config` to enable LoRA hot-loading\nvram_config = {\n    \"offload_dtype\": torch.bfloat16,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\n\n# Load models\npipe = ZImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image\", origin_file_pattern=\"transformer/*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"text_encoder/*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/General-Image-Encoders\", origin_file_pattern=\"SigLIP2-G384/model.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/General-Image-Encoders\", origin_file_pattern=\"DINOv3-7B/model.safetensors\", **vram_config),\n        ModelConfig(model_id=\"DiffSynth-Studio/Z-Image-i2L\", origin_file_pattern=\"model.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"tokenizer/\"),\n    vram_limit=0,\n)\n\n# Load images\nsnapshot_download(\n    model_id=\"DiffSynth-Studio/Z-Image-i2L\",\n    allow_file_pattern=\"assets/style/*\",\n    local_dir=\"data/Z-Image-i2L_style_input\"\n)\nimages = [Image.open(f\"data/Z-Image-i2L_style_input/assets/style/1/{i}.jpg\") for i 
in range(4)]\n\n# Image to LoRA\nwith torch.no_grad():\n    embs = ZImageUnit_Image2LoRAEncode().process(pipe, image2lora_images=images)\n    lora = ZImageUnit_Image2LoRADecode().process(pipe, **embs)[\"lora\"]\nsave_file(lora, \"lora.safetensors\")\n\n# Generate images\nprompt = \"a cat\"\nnegative_prompt = \"泛黄，发绿，模糊，低分辨率，低质量图像，扭曲的肢体，诡异的外观，丑陋，AI感，噪点，网格感，JPEG压缩条纹，异常的肢体，水印，乱码，意义不明的字符\"\nimage = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    seed=0, cfg_scale=4, num_inference_steps=50,\n    positive_only_lora=lora,\n    sigma_shift=8\n)\nimage.save(\"image.jpg\")\n"
  },
  {
    "path": "examples/z_image/model_inference_low_vram/Z-Image.py",
    "content": "from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig\nimport torch\n\nvram_config = {\n    \"offload_dtype\": torch.bfloat16,\n    \"offload_device\": \"cpu\",\n    \"onload_dtype\": torch.bfloat16,\n    \"onload_device\": \"cpu\",\n    \"preparing_dtype\": torch.bfloat16,\n    \"preparing_device\": \"cuda\",\n    \"computation_dtype\": torch.bfloat16,\n    \"computation_device\": \"cuda\",\n}\npipe = ZImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image\", origin_file_pattern=\"transformer/*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"text_encoder/*.safetensors\", **vram_config),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\", **vram_config),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"tokenizer/\"),\n    vram_limit=torch.cuda.mem_get_info(\"cuda\")[1] / (1024 ** 3) - 0.5,\n)\nprompt = \"Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights.\"\nimage = pipe(prompt=prompt, seed=42, rand_device=\"cuda\", num_inference_steps=50, cfg_scale=4)\nimage.save(\"image_Z-Image.jpg\")\n"
  },
  {
    "path": "examples/z_image/model_training/full/Z-Image-Omni-Base.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"z_image/Z-Image-Omni-Base/*\" --local_dir ./data/diffsynth_example_dataset\n\n# This example is tested on 8*A100\n# Text to image training\naccelerate launch --config_file examples/z_image/model_training/full/accelerate_config.yaml examples/z_image/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/z_image/Z-Image-Omni-Base \\\n  --dataset_metadata_path data/diffsynth_example_dataset/z_image/Z-Image-Omni-Base/metadata.csv \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 400 \\\n  --model_id_with_origin_paths \"Tongyi-MAI/Z-Image-Omni-Base:transformer/*.safetensors,Tongyi-MAI/Z-Image-Omni-Base:siglip/model.safetensors,Tongyi-MAI/Z-Image-Turbo:text_encoder/*.safetensors,Tongyi-MAI/Z-Image-Turbo:vae/diffusion_pytorch_model.safetensors\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 2 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Z-Image-Omni-Base_full\" \\\n  --trainable_models \"dit\" \\\n  --use_gradient_checkpointing \\\n  --find_unused_parameters \\\n  --dataset_num_workers 8\n\n# Image(s) to image training\n# accelerate launch --config_file examples/z_image/model_training/full/accelerate_config.yaml examples/z_image/model_training/train.py \\\n#   --dataset_base_path data/diffsynth_example_dataset/z_image/Z-Image-Omni-Base \\\n#   --dataset_metadata_path data/diffsynth_example_dataset/z_image/Z-Image-Omni-Base/metadata.csv \\\n#   --data_file_keys \"image,edit_image\" \\\n#   --extra_inputs \"edit_image\" \\\n#   --max_pixels 1048576 \\\n#   --dataset_repeat 400 \\\n#   --model_id_with_origin_paths \"Tongyi-MAI/Z-Image-Omni-Base:transformer/*.safetensors,Tongyi-MAI/Z-Image-Omni-Base:siglip/model.safetensors,Tongyi-MAI/Z-Image-Turbo:text_encoder/*.safetensors,Tongyi-MAI/Z-Image-Turbo:vae/diffusion_pytorch_model.safetensors\" \\\n#   --learning_rate 1e-5 \\\n#   --num_epochs 2 \\\n#   
--remove_prefix_in_ckpt \"pipe.dit.\" \\\n#   --output_path \"./models/train/Z-Image-Omni-Base_full_edit\" \\\n#   --trainable_models \"dit\" \\\n#   --use_gradient_checkpointing \\\n#   --find_unused_parameters \\\n#   --dataset_num_workers 8\n"
  },
  {
    "path": "examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"z_image/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/z_image/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/z_image/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps \\\n  --dataset_metadata_path data/diffsynth_example_dataset/z_image/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps/metadata.csv \\\n  --data_file_keys \"image,controlnet_image\" \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1:Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.safetensors,Tongyi-MAI/Z-Image-Turbo:transformer/*.safetensors,Tongyi-MAI/Z-Image-Turbo:text_encoder/*.safetensors,Tongyi-MAI/Z-Image-Turbo:vae/diffusion_pytorch_model.safetensors\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 2 \\\n  --remove_prefix_in_ckpt \"pipe.controlnet.\" \\\n  --output_path \"./models/train/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps_full\" \\\n  --trainable_models \"controlnet\" \\\n  --extra_inputs \"controlnet_image\" \\\n  --use_gradient_checkpointing \\\n  --dataset_num_workers 8\n"
  },
  {
    "path": "examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"z_image/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/z_image/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/z_image/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps \\\n  --dataset_metadata_path data/diffsynth_example_dataset/z_image/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps/metadata.csv \\\n  --data_file_keys \"image,controlnet_image\" \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1:Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.safetensors,Tongyi-MAI/Z-Image-Turbo:transformer/*.safetensors,Tongyi-MAI/Z-Image-Turbo:text_encoder/*.safetensors,Tongyi-MAI/Z-Image-Turbo:vae/diffusion_pytorch_model.safetensors\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 2 \\\n  --remove_prefix_in_ckpt \"pipe.controlnet.\" \\\n  --output_path \"./models/train/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps_full\" \\\n  --trainable_models \"controlnet\" \\\n  --extra_inputs \"controlnet_image\" \\\n  --use_gradient_checkpointing \\\n  --dataset_num_workers 8\n"
  },
  {
    "path": "examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"z_image/Z-Image-Turbo-Fun-Controlnet-Union-2.1/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/z_image/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/z_image/Z-Image-Turbo-Fun-Controlnet-Union-2.1 \\\n  --dataset_metadata_path data/diffsynth_example_dataset/z_image/Z-Image-Turbo-Fun-Controlnet-Union-2.1/metadata.csv \\\n  --data_file_keys \"image,controlnet_image\" \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1:Z-Image-Turbo-Fun-Controlnet-Union-2.1.safetensors,Tongyi-MAI/Z-Image-Turbo:transformer/*.safetensors,Tongyi-MAI/Z-Image-Turbo:text_encoder/*.safetensors,Tongyi-MAI/Z-Image-Turbo:vae/diffusion_pytorch_model.safetensors\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 2 \\\n  --remove_prefix_in_ckpt \"pipe.controlnet.\" \\\n  --output_path \"./models/train/Z-Image-Turbo-Fun-Controlnet-Union-2.1_full\" \\\n  --trainable_models \"controlnet\" \\\n  --extra_inputs \"controlnet_image\" \\\n  --use_gradient_checkpointing \\\n  --dataset_num_workers 8\n"
  },
  {
    "path": "examples/z_image/model_training/full/Z-Image-Turbo.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"z_image/Z-Image-Turbo/*\" --local_dir ./data/diffsynth_example_dataset\n\n# This example is tested on 8*A100\naccelerate launch --config_file examples/z_image/model_training/full/accelerate_config.yaml examples/z_image/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/z_image/Z-Image-Turbo \\\n  --dataset_metadata_path data/diffsynth_example_dataset/z_image/Z-Image-Turbo/metadata.csv \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 400 \\\n  --model_id_with_origin_paths \"Tongyi-MAI/Z-Image-Turbo:transformer/*.safetensors,Tongyi-MAI/Z-Image-Turbo:text_encoder/*.safetensors,Tongyi-MAI/Z-Image-Turbo:vae/diffusion_pytorch_model.safetensors\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 2 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Z-Image-Turbo_full\" \\\n  --trainable_models \"dit\" \\\n  --use_gradient_checkpointing \\\n  --dataset_num_workers 8\n"
  },
  {
    "path": "examples/z_image/model_training/full/Z-Image.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"z_image/Z-Image/*\" --local_dir ./data/diffsynth_example_dataset\n\n# This example is tested on 8*A100\naccelerate launch --config_file examples/z_image/model_training/full/accelerate_config.yaml examples/z_image/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/z_image/Z-Image \\\n  --dataset_metadata_path data/diffsynth_example_dataset/z_image/Z-Image/metadata.csv \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 400 \\\n  --model_id_with_origin_paths \"Tongyi-MAI/Z-Image:transformer/*.safetensors,Tongyi-MAI/Z-Image-Turbo:text_encoder/*.safetensors,Tongyi-MAI/Z-Image-Turbo:vae/diffusion_pytorch_model.safetensors\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 2 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Z-Image_full\" \\\n  --trainable_models \"dit\" \\\n  --use_gradient_checkpointing \\\n  --dataset_num_workers 8\n"
  },
  {
    "path": "examples/z_image/model_training/full/accelerate_config.yaml",
    "content": "compute_environment: LOCAL_MACHINE\ndebug: false\ndeepspeed_config:\n  gradient_accumulation_steps: 1\n  offload_optimizer_device: none\n  offload_param_device: none\n  zero3_init_flag: false\n  zero_stage: 2\ndistributed_type: DEEPSPEED\ndowncast_bf16: 'no'\nenable_cpu_affinity: false\nmachine_rank: 0\nmain_training_function: main\nmixed_precision: bf16\nnum_machines: 1\nnum_processes: 8\nrdzv_backend: static\nsame_network: true\ntpu_env: []\ntpu_use_cluster: false\ntpu_use_sudo: false\nuse_cpu: false\n"
  },
  {
    "path": "examples/z_image/model_training/full/accelerate_config_zero3.yaml",
    "content": "compute_environment: LOCAL_MACHINE\ndebug: false\ndeepspeed_config:\n  gradient_accumulation_steps: 1\n  offload_optimizer_device: none\n  offload_param_device: none\n  zero3_init_flag: true\n  zero3_save_16bit_model: true\n  zero_stage: 3\ndistributed_type: DEEPSPEED\ndowncast_bf16: 'no'\nenable_cpu_affinity: false\nmachine_rank: 0\nmain_training_function: main\nmixed_precision: bf16\nnum_machines: 1\nnum_processes: 8\nrdzv_backend: static\nsame_network: true\ntpu_env: []\ntpu_use_cluster: false\ntpu_use_sudo: false\nuse_cpu: false\n"
  },
  {
    "path": "examples/z_image/model_training/lora/Z-Image-Omni-Base.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"z_image/Z-Image-Omni-Base/*\" --local_dir ./data/diffsynth_example_dataset\n\n# Text to image training\naccelerate launch examples/z_image/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/z_image/Z-Image-Omni-Base \\\n  --dataset_metadata_path data/diffsynth_example_dataset/z_image/Z-Image-Omni-Base/metadata.csv \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 50 \\\n  --model_id_with_origin_paths \"Tongyi-MAI/Z-Image-Omni-Base:transformer/*.safetensors,Tongyi-MAI/Z-Image-Omni-Base:siglip/model.safetensors,Tongyi-MAI/Z-Image-Turbo:text_encoder/*.safetensors,Tongyi-MAI/Z-Image-Turbo:vae/diffusion_pytorch_model.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Z-Image-Omni-Base_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"to_q,to_k,to_v,to_out.0,w1,w2,w3\" \\\n  --lora_rank 32 \\\n  --use_gradient_checkpointing \\\n  --find_unused_parameters \\\n  --dataset_num_workers 8\n\n# Image(s) to image training\n# accelerate launch examples/z_image/model_training/train.py \\\n#   --dataset_base_path data/diffsynth_example_dataset/z_image/Z-Image-Omni-Base \\\n#   --dataset_metadata_path data/diffsynth_example_dataset/z_image/Z-Image-Omni-Base/metadata.csv \\\n#   --data_file_keys \"image,edit_image\" \\\n#   --extra_inputs \"edit_image\" \\\n#   --max_pixels 1048576 \\\n#   --dataset_repeat 50 \\\n#   --model_id_with_origin_paths \"Tongyi-MAI/Z-Image-Omni-Base:transformer/*.safetensors,Tongyi-MAI/Z-Image-Omni-Base:siglip/model.safetensors,Tongyi-MAI/Z-Image-Turbo:text_encoder/*.safetensors,Tongyi-MAI/Z-Image-Turbo:vae/diffusion_pytorch_model.safetensors\" \\\n#   --learning_rate 1e-4 \\\n#   --num_epochs 5 \\\n#   --remove_prefix_in_ckpt \"pipe.dit.\" \\\n#   --output_path \"./models/train/Z-Image-Omni-Base_lora_edit\" \\\n#   
--lora_base_model \"dit\" \\\n#   --lora_target_modules \"to_q,to_k,to_v,to_out.0,w1,w2,w3\" \\\n#   --lora_rank 32 \\\n#   --use_gradient_checkpointing \\\n#   --find_unused_parameters \\\n#   --dataset_num_workers 8\n"
  },
  {
    "path": "examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"z_image/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/z_image/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/z_image/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps \\\n  --dataset_metadata_path data/diffsynth_example_dataset/z_image/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps/metadata.csv \\\n  --data_file_keys \"image,controlnet_image\" \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1:Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.safetensors,Tongyi-MAI/Z-Image-Turbo:transformer/*.safetensors,Tongyi-MAI/Z-Image-Turbo:text_encoder/*.safetensors,Tongyi-MAI/Z-Image-Turbo:vae/diffusion_pytorch_model.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"to_q,to_k,to_v,to_out.0,w1,w2,w3\" \\\n  --lora_rank 32 \\\n  --extra_inputs \"controlnet_image\" \\\n  --use_gradient_checkpointing \\\n  --dataset_num_workers 8\n"
  },
  {
    "path": "examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"z_image/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/z_image/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/z_image/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps \\\n  --dataset_metadata_path data/diffsynth_example_dataset/z_image/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps/metadata.csv \\\n  --data_file_keys \"image,controlnet_image\" \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1:Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.safetensors,Tongyi-MAI/Z-Image-Turbo:transformer/*.safetensors,Tongyi-MAI/Z-Image-Turbo:text_encoder/*.safetensors,Tongyi-MAI/Z-Image-Turbo:vae/diffusion_pytorch_model.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"to_q,to_k,to_v,to_out.0,w1,w2,w3\" \\\n  --lora_rank 32 \\\n  --extra_inputs \"controlnet_image\" \\\n  --use_gradient_checkpointing \\\n  --dataset_num_workers 8\n"
  },
  {
    "path": "examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"z_image/Z-Image-Turbo-Fun-Controlnet-Union-2.1/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/z_image/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/z_image/Z-Image-Turbo-Fun-Controlnet-Union-2.1 \\\n  --dataset_metadata_path data/diffsynth_example_dataset/z_image/Z-Image-Turbo-Fun-Controlnet-Union-2.1/metadata.csv \\\n  --data_file_keys \"image,controlnet_image\" \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 100 \\\n  --model_id_with_origin_paths \"PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1:Z-Image-Turbo-Fun-Controlnet-Union-2.1.safetensors,Tongyi-MAI/Z-Image-Turbo:transformer/*.safetensors,Tongyi-MAI/Z-Image-Turbo:text_encoder/*.safetensors,Tongyi-MAI/Z-Image-Turbo:vae/diffusion_pytorch_model.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Z-Image-Turbo-Fun-Controlnet-Union-2.1_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"to_q,to_k,to_v,to_out.0,w1,w2,w3\" \\\n  --lora_rank 32 \\\n  --extra_inputs \"controlnet_image\" \\\n  --use_gradient_checkpointing \\\n  --dataset_num_workers 8\n"
  },
  {
    "path": "examples/z_image/model_training/lora/Z-Image-Turbo.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"z_image/Z-Image-Turbo/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/z_image/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/z_image/Z-Image-Turbo \\\n  --dataset_metadata_path data/diffsynth_example_dataset/z_image/Z-Image-Turbo/metadata.csv \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 50 \\\n  --model_id_with_origin_paths \"Tongyi-MAI/Z-Image-Turbo:transformer/*.safetensors,Tongyi-MAI/Z-Image-Turbo:text_encoder/*.safetensors,Tongyi-MAI/Z-Image-Turbo:vae/diffusion_pytorch_model.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Z-Image-Turbo_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"to_q,to_k,to_v,to_out.0,w1,w2,w3\" \\\n  --lora_rank 32 \\\n  --use_gradient_checkpointing \\\n  --dataset_num_workers 8\n"
  },
  {
    "path": "examples/z_image/model_training/lora/Z-Image.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"z_image/Z-Image/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/z_image/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/z_image/Z-Image \\\n  --dataset_metadata_path data/diffsynth_example_dataset/z_image/Z-Image/metadata.csv \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 50 \\\n  --model_id_with_origin_paths \"Tongyi-MAI/Z-Image:transformer/*.safetensors,Tongyi-MAI/Z-Image-Turbo:text_encoder/*.safetensors,Tongyi-MAI/Z-Image-Turbo:vae/diffusion_pytorch_model.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Z-Image_lora\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"to_q,to_k,to_v,to_out.0,w1,w2,w3\" \\\n  --lora_rank 32 \\\n  --use_gradient_checkpointing \\\n  --dataset_num_workers 8\n"
  },
  {
    "path": "examples/z_image/model_training/special/differential_training/Z-Image-Turbo.sh",
    "content": "# Z-Image-Turbo is a distilled model.\n# After training, it loses its distillation-based acceleration capability,\n# leading to degraded generation quality at fewer inference steps.\n# This issue can be mitigated by using a pre-trained LoRA model to assist the training process.\n# https://www.modelscope.cn/models/ostris/zimage_turbo_training_adapter\n\nmodelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"z_image/Z-Image-Turbo/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/z_image/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/z_image/Z-Image-Turbo \\\n  --dataset_metadata_path data/diffsynth_example_dataset/z_image/Z-Image-Turbo/metadata.csv \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 50 \\\n  --model_id_with_origin_paths \"Tongyi-MAI/Z-Image-Turbo:transformer/*.safetensors,Tongyi-MAI/Z-Image-Turbo:text_encoder/*.safetensors,Tongyi-MAI/Z-Image-Turbo:vae/diffusion_pytorch_model.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Z-Image-Turbo_lora_differential\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"to_q,to_k,to_v,to_out.0,w1,w2,w3\" \\\n  --lora_rank 32 \\\n  --preset_lora_path \"models/ostris/zimage_turbo_training_adapter/zimage_turbo_training_adapter_v1.safetensors\" \\\n  --preset_lora_model \"dit\" \\\n  --use_gradient_checkpointing \\\n  --dataset_num_workers 8\n"
  },
  {
    "path": "examples/z_image/model_training/special/differential_training/validate.py",
    "content": "from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig\nimport torch\n\n\npipe = ZImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"transformer/*.safetensors\"),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"text_encoder/*.safetensors\"),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"tokenizer/\"),\n)\npipe.load_lora(pipe.dit, \"./models/train/Z-Image-Turbo_lora_differential/epoch-4.safetensors\")\nprompt = \"a dog\"\nimage = pipe(prompt=prompt, seed=42, rand_device=\"cuda\")\nimage.save(\"image.jpg\")\n"
  },
  {
    "path": "examples/z_image/model_training/special/npu_training/Z-Image-Turbo-NPU.sh",
    "content": "export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True\nexport CPU_AFFINITY_CONF=1\n\nmodelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"z_image/Z-Image-Turbo/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch --config_file examples/z_image/model_training/full/accelerate_config.yaml examples/z_image/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/z_image/Z-Image-Turbo \\\n  --dataset_metadata_path data/diffsynth_example_dataset/z_image/Z-Image-Turbo/metadata.csv \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 400 \\\n  --model_id_with_origin_paths \"Tongyi-MAI/Z-Image-Turbo:transformer/*.safetensors,Tongyi-MAI/Z-Image-Turbo:text_encoder/*.safetensors,Tongyi-MAI/Z-Image-Turbo:vae/diffusion_pytorch_model.safetensors\" \\\n  --learning_rate 1e-5 \\\n  --num_epochs 2 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Z-Image-Turbo_full\" \\\n  --trainable_models \"dit\" \\\n  --use_gradient_checkpointing \\\n  --dataset_num_workers 8 \\\n  --enable_npu_patch\n"
  },
  {
    "path": "examples/z_image/model_training/special/trajectory_imitation/Z-Image-Turbo.sh",
    "content": "modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include \"z_image/Z-Image-Turbo/*\" --local_dir ./data/diffsynth_example_dataset\n\naccelerate launch examples/z_image/model_training/train.py \\\n  --dataset_base_path data/diffsynth_example_dataset/z_image/Z-Image-Turbo \\\n  --dataset_metadata_path data/diffsynth_example_dataset/z_image/Z-Image-Turbo/metadata.csv \\\n  --max_pixels 1048576 \\\n  --dataset_repeat 50 \\\n  --model_id_with_origin_paths \"Tongyi-MAI/Z-Image-Turbo:transformer/*.safetensors,Tongyi-MAI/Z-Image-Turbo:text_encoder/*.safetensors,Tongyi-MAI/Z-Image-Turbo:vae/diffusion_pytorch_model.safetensors\" \\\n  --learning_rate 1e-4 \\\n  --num_epochs 5 \\\n  --remove_prefix_in_ckpt \"pipe.dit.\" \\\n  --output_path \"./models/train/Z-Image-Turbo_lora_distill\" \\\n  --lora_base_model \"dit\" \\\n  --lora_target_modules \"to_q,to_k,to_v,to_out.0,w1,w2,w3\" \\\n  --lora_rank 32 \\\n  --lora_checkpoint \"./models/train/Z-Image-Turbo_lora/epoch-4.safetensors\" \\\n  --use_gradient_checkpointing \\\n  --dataset_num_workers 8 \\\n  --task \"trajectory_imitation\" \\\n  --save_steps 10\n"
  },
  {
    "path": "examples/z_image/model_training/special/trajectory_imitation/validate.py",
    "content": "from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig\nimport torch\n\n\npipe = ZImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"transformer/*.safetensors\"),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"text_encoder/*.safetensors\"),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"tokenizer/\"),\n)\npipe.load_lora(pipe.dit, \"./models/train/Z-Image-Turbo_lora_distill/step-20.safetensors\")\nprompt = \"a dog\"\nimage = pipe(prompt=prompt, seed=42, rand_device=\"cuda\")\nimage.save(\"image.jpg\")\n"
  },
  {
    "path": "examples/z_image/model_training/train.py",
    "content": "import torch, os, argparse, accelerate, copy\nfrom diffsynth.core import UnifiedDataset\nfrom diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig\nfrom diffsynth.diffusion import *\nos.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n\n\nclass ZImageTrainingModule(DiffusionTrainingModule):\n    def __init__(\n        self,\n        model_paths=None, model_id_with_origin_paths=None,\n        tokenizer_path=None,\n        trainable_models=None,\n        lora_base_model=None, lora_target_modules=\"\", lora_rank=32, lora_checkpoint=None,\n        preset_lora_path=None, preset_lora_model=None,\n        use_gradient_checkpointing=True,\n        use_gradient_checkpointing_offload=False,\n        extra_inputs=None,\n        fp8_models=None,\n        offload_models=None,\n        device=\"cpu\",\n        task=\"sft\",\n        enable_npu_patch=True,\n    ):\n        super().__init__()\n        # Load models\n        model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device)\n        tokenizer_config = ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"tokenizer/\") if tokenizer_path is None else ModelConfig(tokenizer_path)\n        self.pipe = ZImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config, enable_npu_patch=enable_npu_patch)\n        self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model)\n\n        # Training mode\n        self.switch_pipe_to_training_mode(\n            self.pipe, trainable_models,\n            lora_base_model, lora_target_modules, lora_rank, lora_checkpoint,\n            preset_lora_path, preset_lora_model,\n            task=task,\n        )\n        \n        # Other configs\n        self.use_gradient_checkpointing = use_gradient_checkpointing\n        self.use_gradient_checkpointing_offload = 
use_gradient_checkpointing_offload\n        self.extra_inputs = extra_inputs.split(\",\") if extra_inputs is not None else []\n        self.fp8_models = fp8_models\n        self.task = task\n        self.task_to_loss = {\n            \"sft:data_process\": lambda pipe, *args: args,\n            \"direct_distill:data_process\": lambda pipe, *args: args,\n            \"sft\": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),\n            \"sft:train\": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),\n            \"direct_distill\": lambda pipe, inputs_shared, inputs_posi, inputs_nega: DirectDistillLoss(pipe, **inputs_shared, **inputs_posi),\n            \"direct_distill:train\": lambda pipe, inputs_shared, inputs_posi, inputs_nega: DirectDistillLoss(pipe, **inputs_shared, **inputs_posi),\n        }\n        if task == \"trajectory_imitation\":\n            # This is an experimental feature.\n            # We may remove it in the future.\n            self.loss_fn = TrajectoryImitationLoss()\n            self.task_to_loss[\"trajectory_imitation\"] = self.loss_fn\n            self.pipe_teacher = copy.deepcopy(self.pipe)\n            self.pipe_teacher.requires_grad_(False)\n        \n    def get_pipeline_inputs(self, data):\n        inputs_posi = {\"prompt\": data[\"prompt\"]}\n        inputs_nega = {\"negative_prompt\": \"\"}\n        inputs_shared = {\n            # Assume you are using this pipeline for inference,\n            # please fill in the input parameters.\n            \"input_image\": data[\"image\"],\n            \"height\": data[\"image\"].size[1],\n            \"width\": data[\"image\"].size[0],\n            # Please do not modify the following parameters\n            # unless you clearly know what this will cause.\n            \"cfg_scale\": 1,\n            \"rand_device\": self.pipe.device,\n            \"use_gradient_checkpointing\": 
self.use_gradient_checkpointing,\n            \"use_gradient_checkpointing_offload\": self.use_gradient_checkpointing_offload,\n        }\n        if self.task == \"trajectory_imitation\":\n            inputs_shared[\"cfg_scale\"] = 2\n            inputs_shared[\"teacher\"] = self.pipe_teacher\n        inputs_shared = self.parse_extra_inputs(data, self.extra_inputs, inputs_shared)\n        return inputs_shared, inputs_posi, inputs_nega\n    \n    def forward(self, data, inputs=None):\n        if inputs is None: inputs = self.get_pipeline_inputs(data)\n        inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype)\n        for unit in self.pipe.units:\n            inputs = self.pipe.unit_runner(unit, self.pipe, *inputs)\n        loss = self.task_to_loss[self.task](self.pipe, *inputs)\n        return loss\n\n\ndef z_image_parser():\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser = add_general_config(parser)\n    parser = add_image_size_config(parser)\n    parser.add_argument(\"--tokenizer_path\", type=str, default=None, help=\"Path to tokenizer.\")\n    parser.add_argument(\"--enable_npu_patch\", default=False, action=\"store_true\", help=\"Whether to use npu fused operator patch to improve performance in NPU.\")\n    return parser\n\n\nif __name__ == \"__main__\":\n    parser = z_image_parser()\n    args = parser.parse_args()\n    accelerator = accelerate.Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        kwargs_handlers=[accelerate.DistributedDataParallelKwargs(find_unused_parameters=args.find_unused_parameters)],\n    )\n    dataset = UnifiedDataset(\n        base_path=args.dataset_base_path,\n        metadata_path=args.dataset_metadata_path,\n        repeat=args.dataset_repeat,\n        data_file_keys=args.data_file_keys.split(\",\"),\n        main_data_operator=UnifiedDataset.default_image_operator(\n            
base_path=args.dataset_base_path,\n            max_pixels=args.max_pixels,\n            height=args.height,\n            width=args.width,\n            height_division_factor=16,\n            width_division_factor=16,\n        )\n    )\n    model = ZImageTrainingModule(\n        model_paths=args.model_paths,\n        model_id_with_origin_paths=args.model_id_with_origin_paths,\n        tokenizer_path=args.tokenizer_path,\n        trainable_models=args.trainable_models,\n        lora_base_model=args.lora_base_model,\n        lora_target_modules=args.lora_target_modules,\n        lora_rank=args.lora_rank,\n        lora_checkpoint=args.lora_checkpoint,\n        preset_lora_path=args.preset_lora_path,\n        preset_lora_model=args.preset_lora_model,\n        use_gradient_checkpointing=args.use_gradient_checkpointing,\n        use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,\n        extra_inputs=args.extra_inputs,\n        fp8_models=args.fp8_models,\n        offload_models=args.offload_models,\n        task=args.task,\n        device=accelerator.device,\n        enable_npu_patch=args.enable_npu_patch\n    )\n    model_logger = ModelLogger(\n        args.output_path,\n        remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,\n    )\n    launcher_map = {\n        \"sft:data_process\": launch_data_process_task,\n        \"direct_distill:data_process\": launch_data_process_task,\n        \"sft\": launch_training_task,\n        \"sft:train\": launch_training_task,\n        \"direct_distill\": launch_training_task,\n        \"direct_distill:train\": launch_training_task,\n        \"trajectory_imitation\": launch_training_task,\n    }\n    launcher_map[args.task](accelerator, dataset, model, model_logger, args=args)\n"
  },
  {
    "path": "examples/z_image/model_training/validate_full/Z-Image-Omni-Base.py",
    "content": "from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig\nfrom diffsynth.core import load_state_dict\nimport torch\n\n\npipe = ZImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Omni-Base\", origin_file_pattern=\"transformer/*.safetensors\"),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Omni-Base\", origin_file_pattern=\"siglip/model.safetensors\"),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"text_encoder/*.safetensors\"),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"tokenizer/\"),\n)\n\nstate_dict = load_state_dict(\"./models/train/Z-Image-Omni-Base_full/epoch-1.safetensors\", torch_dtype=torch.bfloat16)\npipe.dit.load_state_dict(state_dict)\nprompt = \"a dog\"\nimage = pipe(prompt=prompt, seed=42, rand_device=\"cuda\", num_inference_steps=40, cfg_scale=4)\nimage.save(\"image.jpg\")\n\n# Edit\n# state_dict = load_state_dict(\"./models/train/Z-Image-Omni-Base_full_edit/epoch-1.safetensors\", torch_dtype=torch.bfloat16)\n# pipe.dit.load_state_dict(state_dict)\n# prompt = \"Change the color of the dress in Figure 1 to the color shown in Figure 2.\"\n# images = [\n#     Image.open(\"data/example_image_dataset/edit/image1.jpg\").resize((1024, 1024)),\n#     Image.open(\"data/example_image_dataset/edit/image_color.jpg\").resize((1024, 1024)),\n# ]\n# image = pipe(prompt=prompt, seed=42, rand_device=\"cuda\", num_inference_steps=40, cfg_scale=4, edit_image=images)\n# image.save(\"image.jpg\")\n"
  },
  {
    "path": "examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py",
    "content": "from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig, ControlNetInput\nfrom diffsynth import load_state_dict\nfrom PIL import Image\nimport torch\n\n\npipe = ZImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1\", origin_file_pattern=\"Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.safetensors\"),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"transformer/*.safetensors\"),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"text_encoder/*.safetensors\"),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"tokenizer/\"),\n)\nstate_dict = load_state_dict(\"./models/train/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps_full/epoch-1.safetensors\")\npipe.controlnet.load_state_dict(state_dict)\n\ncontrolnet_image = Image.open(\"data/example_image_dataset/upscale/image_1.jpg\").resize((1024, 1024))\nprompt = \"a dog\"\nimage = pipe(prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(image=controlnet_image, scale=1)])\nimage.save(\"image_tile.jpg\")\n"
  },
  {
    "path": "examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py",
    "content": "from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig, ControlNetInput\nfrom diffsynth import load_state_dict\nfrom PIL import Image\nimport torch\n\n\npipe = ZImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1\", origin_file_pattern=\"Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.safetensors\"),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"transformer/*.safetensors\"),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"text_encoder/*.safetensors\"),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"tokenizer/\"),\n)\nstate_dict = load_state_dict(\"./models/train/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps_full/epoch-1.safetensors\")\npipe.controlnet.load_state_dict(state_dict)\n\ncontrolnet_image = Image.open(\"data/example_image_dataset/canny/image_1.jpg\").resize((1024, 1024))\nprompt = \"a dog\"\nimage = pipe(prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(image=controlnet_image, scale=0.7)])\nimage.save(\"image_control.jpg\")\n"
  },
  {
    "path": "examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py",
    "content": "from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig, ControlNetInput\nfrom diffsynth import load_state_dict\nfrom PIL import Image\nimport torch\n\n\npipe = ZImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1\", origin_file_pattern=\"Z-Image-Turbo-Fun-Controlnet-Union-2.1.safetensors\"),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"transformer/*.safetensors\"),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"text_encoder/*.safetensors\"),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"tokenizer/\"),\n)\nstate_dict = load_state_dict(\"./models/train/Z-Image-Turbo-Fun-Controlnet-Union-2.1_full/epoch-1.safetensors\")\npipe.controlnet.load_state_dict(state_dict)\n\ncontrolnet_image = Image.open(\"data/example_image_dataset/canny/image_1.jpg\").resize((1024, 1024))\nprompt = \"a dog\"\nimage = pipe(prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(image=controlnet_image, scale=0.7)])\nimage.save(\"image_control.jpg\")\n"
  },
  {
    "path": "examples/z_image/model_training/validate_full/Z-Image-Turbo.py",
    "content": "from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig\nfrom diffsynth.core import load_state_dict\nimport torch\n\n\npipe = ZImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"transformer/*.safetensors\"),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"text_encoder/*.safetensors\"),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"tokenizer/\"),\n)\nstate_dict = load_state_dict(\"./models/train/Z-Image-Turbo_full/epoch-1.safetensors\", torch_dtype=torch.bfloat16)\npipe.dit.load_state_dict(state_dict)\nprompt = \"a dog\"\nimage = pipe(prompt=prompt, seed=42, rand_device=\"cuda\")\nimage.save(\"image.jpg\")\n"
  },
  {
    "path": "examples/z_image/model_training/validate_full/Z-Image.py",
    "content": "from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig\nfrom diffsynth.core import load_state_dict\nimport torch\n\n\npipe = ZImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image\", origin_file_pattern=\"transformer/*.safetensors\"),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"text_encoder/*.safetensors\"),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"tokenizer/\"),\n)\nstate_dict = load_state_dict(\"./models/train/Z-Image_full/epoch-1.safetensors\", torch_dtype=torch.bfloat16)\npipe.dit.load_state_dict(state_dict)\nprompt = \"a dog\"\nimage = pipe(prompt=prompt, seed=42, rand_device=\"cuda\", num_inference_steps=50, cfg_scale=4)\nimage.save(\"image.jpg\")\n"
  },
  {
    "path": "examples/z_image/model_training/validate_lora/Z-Image-Omni-Base.py",
    "content": "from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig\nfrom PIL import Image\nimport torch\n\n\npipe = ZImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Omni-Base\", origin_file_pattern=\"transformer/*.safetensors\"),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Omni-Base\", origin_file_pattern=\"siglip/model.safetensors\"),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"text_encoder/*.safetensors\"),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"tokenizer/\"),\n)\n\npipe.load_lora(pipe.dit, \"./models/train/Z-Image-Omni-Base_lora/epoch-4.safetensors\")\nprompt = \"a dog\"\nimage = pipe(prompt=prompt, seed=42, rand_device=\"cuda\", num_inference_steps=40, cfg_scale=4)\nimage.save(\"image.jpg\")\n\n# Edit\n# pipe.load_lora(pipe.dit, \"./models/train/Z-Image-Omni-Base_lora_edit/epoch-4.safetensors\")\n# prompt = \"Change the color of the dress in Figure 1 to the color shown in Figure 2.\"\n# images = [\n#     Image.open(\"data/example_image_dataset/edit/image1.jpg\").resize((1024, 1024)),\n#     Image.open(\"data/example_image_dataset/edit/image_color.jpg\").resize((1024, 1024)),\n# ]\n# image = pipe(prompt=prompt, seed=42, rand_device=\"cuda\", num_inference_steps=40, cfg_scale=4, edit_image=images)\n# image.save(\"image.jpg\")\n"
  },
  {
    "path": "examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py",
    "content": "from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig, ControlNetInput\nfrom diffsynth import load_state_dict\nfrom PIL import Image\nimport torch\n\n\npipe = ZImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1\", origin_file_pattern=\"Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.safetensors\"),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"transformer/*.safetensors\"),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"text_encoder/*.safetensors\"),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"tokenizer/\"),\n)\npipe.load_lora(pipe.dit, \"./models/train/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps_lora/epoch-4.safetensors\")\n\ncontrolnet_image = Image.open(\"data/example_image_dataset/upscale/image_1.jpg\").resize((1024, 1024))\nprompt = \"a dog\"\nimage = pipe(prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(image=controlnet_image, scale=1)])\nimage.save(\"image_tile.jpg\")\n"
  },
  {
    "path": "examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py",
    "content": "from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig, ControlNetInput\nfrom diffsynth import load_state_dict\nfrom PIL import Image\nimport torch\n\n\npipe = ZImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1\", origin_file_pattern=\"Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.safetensors\"),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"transformer/*.safetensors\"),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"text_encoder/*.safetensors\"),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"tokenizer/\"),\n)\npipe.load_lora(pipe.dit, \"./models/train/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps_lora/epoch-4.safetensors\")\n\ncontrolnet_image = Image.open(\"data/example_image_dataset/canny/image_1.jpg\").resize((1024, 1024))\nprompt = \"a dog\"\nimage = pipe(prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(image=controlnet_image, scale=0.7)])\nimage.save(\"image_control.jpg\")\n"
  },
  {
    "path": "examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py",
    "content": "from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig, ControlNetInput\nfrom diffsynth import load_state_dict\nfrom PIL import Image\nimport torch\n\n\npipe = ZImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1\", origin_file_pattern=\"Z-Image-Turbo-Fun-Controlnet-Union-2.1.safetensors\"),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"transformer/*.safetensors\"),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"text_encoder/*.safetensors\"),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"tokenizer/\"),\n)\npipe.load_lora(pipe.dit, \"./models/train/Z-Image-Turbo-Fun-Controlnet-Union-2.1_lora/epoch-4.safetensors\")\n\ncontrolnet_image = Image.open(\"data/example_image_dataset/canny/image_1.jpg\").resize((1024, 1024))\nprompt = \"a dog\"\nimage = pipe(prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(image=controlnet_image, scale=0.7)])\nimage.save(\"image_control.jpg\")\n"
  },
  {
    "path": "examples/z_image/model_training/validate_lora/Z-Image-Turbo.py",
    "content": "from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig\nimport torch\n\n\npipe = ZImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"transformer/*.safetensors\"),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"text_encoder/*.safetensors\"),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"tokenizer/\"),\n)\npipe.load_lora(pipe.dit, \"./models/train/Z-Image-Turbo_lora/epoch-4.safetensors\")\nprompt = \"a dog\"\nimage = pipe(prompt=prompt, seed=42, rand_device=\"cuda\")\nimage.save(\"image.jpg\")\n"
  },
  {
    "path": "examples/z_image/model_training/validate_lora/Z-Image.py",
    "content": "from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig\nimport torch\n\n\npipe = ZImagePipeline.from_pretrained(\n    torch_dtype=torch.bfloat16,\n    device=\"cuda\",\n    model_configs=[\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image\", origin_file_pattern=\"transformer/*.safetensors\"),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"text_encoder/*.safetensors\"),\n        ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n    ],\n    tokenizer_config=ModelConfig(model_id=\"Tongyi-MAI/Z-Image-Turbo\", origin_file_pattern=\"tokenizer/\"),\n)\npipe.load_lora(pipe.dit, \"./models/train/Z-Image_lora/epoch-4.safetensors\")\nprompt = \"a dog\"\nimage = pipe(prompt=prompt, seed=42, rand_device=\"cuda\", num_inference_steps=50, cfg_scale=4)\nimage.save(\"image.jpg\")\n"
  },
  {
    "path": "pyproject.toml",
    "content": "[build-system]\nrequires = [\"setuptools>=61.0\", \"wheel\"]\nbuild-backend = \"setuptools.build_meta\"\n\n[project]\nname = \"diffsynth\"\nversion = \"2.0.6\"\ndescription = \"Enjoy the magic of Diffusion models!\"\nauthors = [{name = \"ModelScope Team\"}]\nlicense = {text = \"Apache-2.0\"}\nrequires-python = \">=3.10.1\"\ndependencies = [\n    \"torch>=2.0.0\",\n    \"torchvision\",\n    \"transformers\",\n    \"imageio\",\n    \"imageio[ffmpeg]\",\n    \"safetensors\",\n    \"einops\",\n    \"sentencepiece\",\n    \"protobuf\",\n    \"modelscope\",\n    \"ftfy\",\n    \"pandas\",\n    \"accelerate\",\n    \"peft\",\n    \"datasets\",\n]\nclassifiers = [\n    \"Programming Language :: Python :: 3\",\n    \"License :: OSI Approved :: Apache Software License\",\n    \"Operating System :: OS Independent\",\n]\n\n[tool.setuptools.packages.find]\nwhere = [\"./\"]\ninclude = [\"diffsynth\", \"diffsynth.*\"]\n\n[project.optional-dependencies]\nnpu_aarch64 = [\n    \"torch==2.7.1\",\n    \"torch-npu==2.7.1\",\n    \"torchvision==0.22.1\"\n]\nnpu = [\n    \"torch==2.7.1+cpu\",\n    \"torch-npu==2.7.1\",\n    \"torchvision==0.22.1+cpu\"\n]\naudio = [\n    \"torchaudio\",\n    \"torchcodec\"\n]\n\n[tool.setuptools]\ninclude-package-data = true\n"
  }
]