[
  {
    "path": ".gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[codz]\n*$py.class\n\n# C extensions\n*.so\noutputs/\nprompts/\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py.cover\n.hypothesis/\n.pytest_cache/\ncover/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\n.pybuilder/\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n#   For a library or package, you might want to ignore these files since the code is\n#   intended to run in multiple environments; otherwise, check them in:\n# .python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# UV\n#   Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.\n#   This is especially recommended for binary packages to ensure reproducibility, and is more\n#   commonly ignored for libraries.\n#uv.lock\n\n# poetry\n#   Similar to Pipfile.lock, it is generally recommended to include 
poetry.lock in version control.\n#   This is especially recommended for binary packages to ensure reproducibility, and is more\n#   commonly ignored for libraries.\n#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control\n#poetry.lock\n#poetry.toml\n\n# pdm\n#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.\n#   pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.\n#   https://pdm-project.org/en/latest/usage/project/#working-with-version-control\n#pdm.lock\n#pdm.toml\n.pdm-python\n.pdm-build/\n\n# pixi\n#   Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.\n#pixi.lock\n#   Pixi creates a virtual environment in the .pixi directory, just like venv module creates one\n#   in the .venv directory. It is recommended not to include this directory in version control.\n.pixi\n\n# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.envrc\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n\n# pytype static type analyzer\n.pytype/\n\n# Cython debug symbols\ncython_debug/\n\n# PyCharm\n#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can\n#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore\n#  and can be added to the global gitignore or merged into this file.  
For a more nuclear\n#  option (not recommended) you can uncomment the following to ignore the entire idea folder.\n#.idea/\n\n# Abstra\n# Abstra is an AI-powered process automation framework.\n# Ignore directories containing user credentials, local state, and settings.\n# Learn more at https://abstra.io/docs\n.abstra/\n\n# Visual Studio Code\n#  Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore \n#  that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore\n#  and can be added to the global gitignore or merged into this file. However, if you prefer, \n#  you could uncomment the following to ignore the entire vscode folder\n# .vscode/\n\n# Ruff stuff:\n.ruff_cache/\n\n# PyPI configuration file\n.pypirc\n\n# Cursor\n#  Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to\n#  exclude from AI features like autocomplete and code analysis. Recommended for sensitive data\n#  refer to https://docs.cursor.com/context/ignore-files\n.cursorignore\n.cursorindexingignore\n\n# Marimo\nmarimo/_static/\nmarimo/_lsp/\n__marimo__/\n\n# Z-Image\nckpts/\n.isort.cfg\n.pre-commit-config.yaml\n*.DS_Store\n\n# Ignore generated images\n/*.png"
  },
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "README.md",
    "content": "<h1 align=\"center\">⚡️- Image<br><sub><sup>An Efficient Image Generation Foundation Model with Single-Stream Diffusion Transformer</sup></sub></h1>\n\n<div align=\"center\">\n\n[![Official Site](https://img.shields.io/badge/Official%20Site-333399.svg?logo=homepage)](https://tongyi-mai.github.io/Z-Image-blog/)&#160;\n[![Hugging Face](https://img.shields.io/badge/%F0%9F%A4%97%20Checkpoint-Z--Image-yellow)](https://huggingface.co/Tongyi-MAI/Z-Image)&#160;\n[![Hugging Face](https://img.shields.io/badge/%F0%9F%A4%97%20Checkpoint-Z--Image--Turbo-yellow)](https://huggingface.co/Tongyi-MAI/Z-Image-Turbo)&#160;\n[![Hugging Face](https://img.shields.io/badge/%F0%9F%A4%97%20Online_Demo-Z--Image-blue)](https://huggingface.co/spaces/Tongyi-MAI/Z-Image)&#160;\n[![Hugging Face](https://img.shields.io/badge/%F0%9F%A4%97%20Online_Demo-Z--Image--Turbo-blue)](https://huggingface.co/spaces/Tongyi-MAI/Z-Image-Turbo)&#160;\n[![ModelScope Model](https://img.shields.io/badge/🤖%20Checkpoint-Z--Image-624aff)](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image)&#160;\n[![ModelScope Model](https://img.shields.io/badge/🤖%20Checkpoint-Z--Image--Turbo-624aff)](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo)&#160;\n[![ModelScope Space](https://img.shields.io/badge/🤖%20Online_Demo-Z--Image-17c7a7)](https://www.modelscope.cn/aigc/imageGeneration?tab=advanced&versionId=569345&modelType=Checkpoint&sdVersion=Z_IMAGE&modelUrl=modelscope%3A%2F%2FTongyi-MAI%2FZ-Image%3Frevision%3Dmaster)&#160;\n[![ModelScope Space](https://img.shields.io/badge/🤖%20Online_Demo-Z--Image--Turbo-17c7a7)](https://www.modelscope.cn/aigc/imageGeneration?tab=advanced&versionId=469191&modelType=Checkpoint&sdVersion=Z_IMAGE_TURBO&modelUrl=modelscope%3A%2F%2FTongyi-MAI%2FZ-Image-Turbo%3Frevision%3Dmaster)&#160;\n[![Art Gallery PDF](https://img.shields.io/badge/%F0%9F%96%BC%20Art_Gallery-PDF-ff69b4)](assets/Z-Image-Gallery.pdf)&#160;\n[![Web Art 
Gallery](https://img.shields.io/badge/%F0%9F%8C%90%20Web_Art_Gallery-online-00bfff)](https://modelscope.cn/studios/Tongyi-MAI/Z-Image-Gallery/summary)&#160;\n<a href=\"https://arxiv.org/abs/2511.22699\" target=\"_blank\"><img src=\"https://img.shields.io/badge/Report-b5212f.svg?logo=arxiv\" height=\"21px\"></a>\n\n\nWelcome to the official repository for the Z-Image（造相）project!\n\n</div>\n\n\n\n## ✨ Z-Image\n\nZ-Image is a powerful and highly efficient image generation model family with **6B** parameters. Currently there are four variants:\n\n- 🚀 **Z-Image-Turbo** – A distilled version of Z-Image that matches or exceeds leading competitors with only **8 NFEs** (Number of Function Evaluations). It offers **⚡️sub-second inference latency⚡️** on enterprise-grade H800 GPUs and fits comfortably on consumer devices with **16GB of VRAM**. It excels in photorealistic image generation, bilingual text rendering (English & Chinese), and robust instruction adherence.\n\n- 🎨 **Z-Image** – The foundation model behind Z-Image-Turbo. Z-Image focuses on **high-quality generation**, **rich aesthetics**, **strong diversity**, and **controllability**, well-suited for creative generation, **fine-tuning**, and downstream development. It supports a wide range of artistic styles, effective negative prompting, and high diversity across identities, poses, compositions, and layouts.\n\n- 🧱 **Z-Image-Omni-Base** – The versatile foundation model capable of both **generation and editing tasks**. By releasing this checkpoint, we aim to unlock the full potential for community-driven fine-tuning and custom development, providing the most \"raw\" and diverse starting point for the open-source community.\n\n- ✍️ **Z-Image-Edit** – A variant fine-tuned on Z-Image specifically for image editing tasks. 
It supports creative image-to-image generation with impressive instruction-following capabilities, allowing for precise edits based on natural language prompts.\n\n### 📣 News\n\n*   **[2026-01-27]** 🔥 **Z-Image is released!** We have released the model checkpoint on [Hugging Face](https://huggingface.co/Tongyi-MAI/Z-Image) and [ModelScope](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image). Try our [online demo](https://www.modelscope.cn/aigc/imageGeneration?tab=advanced&versionId=569345&modelType=Checkpoint&sdVersion=Z_IMAGE&modelUrl=modelscope%3A%2F%2FTongyi-MAI%2FZ-Image%3Frevision%3Dmaster)!\n*   **[2025-12-08]** 🏆 Z-Image-Turbo ranked 8th overall on the **Artificial Analysis Text-to-Image Leaderboard**, making it the 🥇 <strong style=\"color: #FFC300;\">#1 open-source model</strong>! [Check out the full leaderboard](https://artificialanalysis.ai/image/leaderboard/text-to-image).\n*   **[2025-12-01]** 🎉 Our technical report for Z-Image is now available on [arXiv](https://arxiv.org/abs/2511.22699).\n*   **[2025-11-26]** 🔥 **Z-Image-Turbo is released!** We have released the model checkpoint on [Hugging Face](https://huggingface.co/Tongyi-MAI/Z-Image-Turbo) and [ModelScope](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo). Try our [online demo](https://huggingface.co/spaces/Tongyi-MAI/Z-Image-Turbo)!\n\n### 📥 Model Zoo\n\n| Model | Pre-Training | SFT | RL | Step | CFG | Task | Visual Quality | Diversity | Fine-Tunability | Hugging Face | ModelScope |\n| :--- | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |\n| **Z-Image-Omni-Base** | ✅ | ❌ | ❌ | 50 | ✅ | Gen. / Editing | Medium | High | Easy | *To be released* | *To be released* |\n| **Z-Image** | ✅ | ✅ | ❌ | 50 | ✅ | Gen. 
| High | Medium | Easy | [![Hugging Face](https://img.shields.io/badge/%F0%9F%A4%97%20Checkpoint%20-Z--Image-yellow)](https://huggingface.co/Tongyi-MAI/Z-Image) <br> [![Hugging Face Space](https://img.shields.io/badge/%F0%9F%A4%97%20Demo-Z--Image-blue)](https://huggingface.co/spaces/Tongyi-MAI/Z-Image) | [![ModelScope Model](https://img.shields.io/badge/🤖%20%20Checkpoint-Z--Image-624aff)](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image) <br> [![ModelScope Space](https://img.shields.io/badge/%F0%9F%A4%96%20Demo-Z--Image-17c7a7)](https://www.modelscope.cn/aigc/imageGeneration?tab=advanced&versionId=569345&modelType=Checkpoint&sdVersion=Z_IMAGE&modelUrl=modelscope%3A%2F%2FTongyi-MAI%2FZ-Image%3Frevision%3Dmaster) |\n| **Z-Image-Turbo** | ✅ | ✅ | ✅ | 8 | ❌ | Gen. | Very High | Low | N/A | [![Hugging Face](https://img.shields.io/badge/%F0%9F%A4%97%20Checkpoint%20-Z--Image--Turbo-yellow)](https://huggingface.co/Tongyi-MAI/Z-Image-Turbo) <br> [![Hugging Face Space](https://img.shields.io/badge/%F0%9F%A4%97%20Demo-Z--Image--Turbo-blue)](https://huggingface.co/spaces/Tongyi-MAI/Z-Image-Turbo) | [![ModelScope Model](https://img.shields.io/badge/🤖%20%20Checkpoint-Z--Image--Turbo-624aff)](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo) <br> [![ModelScope Space](https://img.shields.io/badge/%F0%9F%A4%96%20Demo-Z--Image--Turbo-17c7a7)](https://www.modelscope.cn/aigc/imageGeneration?tab=advanced&versionId=469191&modelType=Checkpoint&sdVersion=Z_IMAGE_TURBO&modelUrl=modelscope%3A%2F%2FTongyi-MAI%2FZ-Image-Turbo%3Frevision%3Dmaster) |\n| **Z-Image-Edit** | ✅ | ✅ | ❌ | 50 | ✅ | Editing | High | Medium | Easy | *To be released* | *To be released* |\n\nThe figure below illustrates at which training stage each model is produced.\n\n![Training Pipeline of Z-Image](assets/training_pipeline.jpg)\n\n### 🖼️ Showcase\n\n📸 **Photorealistic Quality**: **Z-Image-Turbo** delivers strong photorealistic image generation while maintaining excellent aesthetic quality.\n\n![Showcase of 
Z-Image on Photo-realistic image Generation](assets/showcase_realistic.png)\n\n📖 **Accurate Bilingual Text Rendering**: **Z-Image-Turbo** excels at accurately rendering complex Chinese and English text.\n\n![Showcase of Z-Image on Bilingual Text Rendering](assets/showcase_rendering.png)\n\n💡  **Prompt Enhancing & Reasoning**: Prompt Enhancer empowers the model with reasoning capabilities, enabling it to transcend surface-level descriptions and tap into underlying world knowledge.\n\n![reasoning.jpg](assets/reasoning.png)\n\n🧠 **Creative Image Editing**: **Z-Image-Edit** shows a strong understanding of bilingual editing instructions, enabling imaginative and flexible image transformations.\n\n![Showcase of Z-Image-Edit on Image Editing](assets/showcase_editing.png)\n\n### 🏗️ Model Architecture\nWe adopt a **Scalable Single-Stream DiT** (S3-DiT) architecture. In this setup, text, visual semantic tokens, and image VAE tokens are concatenated at the sequence level to serve as a unified input stream, maximizing parameter efficiency compared to dual-stream approaches.\n\n![Architecture of Z-Image and Z-Image-Edit](assets/architecture.webp)\n\n### 📈 Performance\n\nZ-Image-Turbo's performance has been validated on multiple independent benchmarks, where it consistently demonstrates state-of-the-art results, especially as the leading open-source model.\n\n#### Artificial Analysis Text-to-Image Leaderboard\nOn the highly competitive [Artificial Analysis Leaderboard](https://artificialanalysis.ai/image/leaderboard/text-to-image), Z-Image-Turbo ranked **8th overall** and secured the top position as the 🥇 <strong style=\"color: gold;\">#1 Open-Source Model</strong>, outperforming all other open-source alternatives.\n\n\n<p align=\"center\">\n  <a href=\"https://artificialanalysis.ai/image/leaderboard/text-to-image\">\n    <img src=\"assets/image_arena_all.jpg\" alt=\"Z-Image Rank on Artificial Analysis Leaderboard\"/><br />\n    <span style=\"font-size:1.05em; cursor:pointer; 
text-decoration:underline;\"> Artificial Analysis Leaderboard</span>\n  </a>\n</p>\n\n<p align=\"center\">\n  <a href=\"https://artificialanalysis.ai/image/leaderboard/text-to-image\">\n    <img src=\"assets/image_arena_os.jpg\" alt=\"Z-Image Rank on Artificial Analysis Leaderboard (Open-Source Model Only)\"/><br />\n    <span style=\"font-size:1.05em; cursor:pointer; text-decoration:underline;\"> Artificial Analysis Leaderboard (Open-Source Model Only)</span>\n  </a>\n</p>\n\n#### Alibaba AI Arena Text-to-Image Leaderboard\nAccording to the Elo-based Human Preference Evaluation on [*Alibaba AI Arena*](https://aiarena.alibaba-inc.com/corpora/arena/leaderboard?arenaType=T2I), Z-Image-Turbo also achieves state-of-the-art results among open-source models and shows highly competitive performance against leading proprietary models.\n\n<p align=\"center\">\n  <a href=\"https://aiarena.alibaba-inc.com/corpora/arena/leaderboard?arenaType=T2I\">\n    <img src=\"assets/leaderboard.png\" alt=\"Z-Image Elo Rating on AI Arena\"/><br />\n    <span style=\"font-size:1.05em; cursor:pointer; text-decoration:underline;\"> Alibaba AI Arena Text-to-Image Leaderboard</span>\n  </a>\n</p>\n\n\n### 🚀 Quick Start\n#### (1) PyTorch Native Inference\nCreate a virtual environment of your choice, then install the dependencies:\n```bash\npip install -e .\n```\nThen run the following command to generate an image:\n```bash\npython inference.py\n```\n\n#### (2) Diffusers Inference\nTo install the latest version of diffusers from source, use the following command:\n<details>\n  <summary>Click here for details on why you need to install diffusers from source</summary>\n\n  We have submitted two pull requests ([#12703](https://github.com/huggingface/diffusers/pull/12703) and [#12715](https://github.com/huggingface/diffusers/pull/12715)) to the 🤗 diffusers repository to add support for Z-Image. 
Both PRs have been merged into the diffusers repository.\n  Install diffusers from source to make sure you get the latest features and Z-Image support.\n\n</details>\n\n```bash\npip install git+https://github.com/huggingface/diffusers\n```\n\n<details>\n<summary><b>Z-Image-Turbo</b> - Click to expand</summary>\n\nThen, try the following code to generate an image:\n```python\nimport torch\nfrom diffusers import ZImagePipeline\n\n# 1. Load the pipeline\n# Use bfloat16 for optimal performance on supported GPUs\npipe = ZImagePipeline.from_pretrained(\n    \"Tongyi-MAI/Z-Image-Turbo\",\n    torch_dtype=torch.bfloat16,\n    low_cpu_mem_usage=False,\n)\npipe.to(\"cuda\")\n\n# [Optional] Attention Backend\n# Diffusers uses SDPA by default. Switch to Flash Attention for better efficiency if supported:\n# pipe.transformer.set_attention_backend(\"flash\")    # Enable Flash-Attention-2\n# pipe.transformer.set_attention_backend(\"_flash_3\") # Enable Flash-Attention-3\n\n# [Optional] Model Compilation\n# Compiling the DiT model accelerates inference, but the first run will take longer to compile.\n# pipe.transformer.compile()\n\n# [Optional] CPU Offloading\n# Enable CPU offloading for memory-constrained devices.\n# pipe.enable_model_cpu_offload()\n\nprompt = \"Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights.\"\n\n# 2. 
Generate Image\nimage = pipe(\n    prompt=prompt,\n    height=1024,\n    width=1024,\n    num_inference_steps=9,  # This actually results in 8 DiT forwards\n    guidance_scale=0.0,     # Guidance should be 0 for the Turbo models\n    generator=torch.Generator(\"cuda\").manual_seed(42),\n).images[0]\n\nimage.save(\"example.png\")\n```\n\n</details>\n\n<details>\n<summary><b>Z-Image</b> - Click to expand</summary>\n\nRecommended Parameters:\n- **Resolution:** 512×512 to 2048×2048 (total pixel area, any aspect ratio)\n- **Guidance scale:** 3.0 – 5.0\n- **Inference steps:** 28 – 50\n- **Negative prompts:** Strongly recommended for better control\n- **CFG normalization:** `False` for general stylization, `True` for realism\n\nThen, try the following code to generate an image:\n```python\nimport torch\nfrom diffusers import ZImagePipeline\n\n# Load the pipeline\npipe = ZImagePipeline.from_pretrained(\n    \"Tongyi-MAI/Z-Image\",\n    torch_dtype=torch.bfloat16,\n    low_cpu_mem_usage=False,\n)\npipe.to(\"cuda\")\n\n# Generate image\nprompt = \"两名年轻亚裔女性紧密站在一起，背景为朴素的灰色纹理墙面，可能是室内地毯地面。左侧女性留着长卷发，身穿藏青色毛衣，左袖有奶油色褶皱装饰，内搭白色立领衬衫，下身白色裤子；佩戴小巧金色耳钉，双臂交叉于背后。右侧女性留直肩长发，身穿奶油色卫衣，胸前印有\\\"Tun the tables\\\"字样，下方为\\\"New ideas\\\"，搭配白色裤子；佩戴银色小环耳环，双臂交叉于胸前。两人均面带微笑直视镜头。照片，自然光照明，柔和阴影，以藏青、奶油白为主的中性色调，休闲时尚摄影，中等景深，面部和上半身对焦清晰，姿态放松，表情友好，室内环境，地毯地面，纯色背景。\"\nnegative_prompt = \"\" # Optional; useful for suppressing unwanted content\n\nimage = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    height=1280,\n    width=720,\n    cfg_normalization=False,\n    num_inference_steps=50,\n    guidance_scale=4,\n    generator=torch.Generator(\"cuda\").manual_seed(42),\n).images[0]\n\nimage.save(\"example.png\")\n```\n\n</details>\n\n## 🔬 Decoupled-DMD: The Acceleration Magic Behind Z-Image\n\n[![arXiv](https://img.shields.io/badge/arXiv-2511.22677-b31b1b.svg)](https://arxiv.org/abs/2511.22677)\n\nDecoupled-DMD is the core few-step distillation algorithm that empowers the 8-step 
Z-Image model.\n\nOur core insight in Decoupled-DMD  is that the success of existing DMD (Distribution Matching Distillation) methods is the result of two independent, collaborating mechanisms:\n\n-   **CFG Augmentation (CA)**: The primary **engine** 🚀 driving the distillation process, a factor largely overlooked in previous work.\n-   **Distribution Matching (DM)**: Acts more as a **regularizer** ⚖️, ensuring the stability and quality of the generated output.\n\nBy recognizing and decoupling these two mechanisms, we were able to study and optimize them in isolation. This ultimately motivated us to develop an improved distillation process that significantly enhances the performance of few-step generation.\n\n![Diagram of Decoupled-DMD](assets/decoupled-dmd.webp)\n\n## 🤖 DMDR: Fusing DMD with Reinforcement Learning\n\n[![arXiv](https://img.shields.io/badge/arXiv-2511.13649-b31b1b.svg)](https://arxiv.org/abs/2511.13649)\n\nBuilding upon the strong foundation of Decoupled-DMD, our 8-step Z-Image model has already demonstrated exceptional capabilities. To achieve further improvements in terms of semantic alignment, aesthetic quality, and structural coherence—while producing images with richer high-frequency details—we present **DMDR**.\n\nOur core insight behind DMDR is that Reinforcement Learning (RL) and Distribution Matching Distillation (DMD) can be synergistically integrated during the post-training of few-step models. We demonstrate that:\n\n-   **RL Unlocks the Performance of DMD** 🚀\n-   **DMD Effectively Regularizes RL** ⚖️\n\n![Diagram of DMDR](assets/DMDR.webp)\n\n## 🎉 Community Works\n\n- [Cache-DiT](https://github.com/vipshop/cache-dit) provides inference acceleration for **Z-Image** and **Z-Image-ControlNet** via DBCache, Context Parallelism and Tensor Parallelism. It achieves nearly **4x** speedup on 4 GPUs with negligible precision loss. 
Please visit their [examples](https://github.com/vipshop/cache-dit/blob/main/examples) for more details.\n- [stable-diffusion.cpp](https://github.com/leejet/stable-diffusion.cpp) is a pure C++ diffusion model inference engine that supports fast and memory-efficient Z-Image inference across multiple platforms (CUDA, Vulkan, etc.). You can use stable-diffusion.cpp to generate images with Z-Image on machines with as little as **4GB** of VRAM. For more information, please refer to [How to Use Z‐Image on a GPU with Only 4GB VRAM](https://github.com/leejet/stable-diffusion.cpp/wiki/How-to-Use-Z%E2%80%90Image-on-a-GPU-with-Only-4GB-VRAM).\n- [LeMiCa](https://github.com/UnicomAI/LeMiCa) provides a training-free, timestep-level acceleration method that conveniently speeds up Z-Image inference. For more details, see [LeMiCa4Z-Image](https://github.com/UnicomAI/LeMiCa/tree/main/LeMiCa4Z-Image).\n- [ComfyUI ZImageLatent](https://github.com/HellerCommaA/ComfyUI-ZImageLatent) provides easy-to-use latents at the official Z-Image resolutions.\n- [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio) provides extended support for Z-Image, including LoRA training, full training, distillation training, and low-VRAM inference. 
Please refer to the DiffSynth-Studio [documentation](https://github.com/modelscope/DiffSynth-Studio/blob/main/docs/en/Model_Details/Z-Image.md) for details.\n- [vllm-omni](https://github.com/vllm-project/vllm-omni), a framework for fast inference and serving of omni-modality models, now [supports](https://github.com/vllm-project/vllm-omni/blob/main/docs/models/supported_models.md) Z-Image.\n- [SGLang-Diffusion](https://lmsys.org/blog/2025-11-07-sglang-diffusion/) brings SGLang's state-of-the-art performance to image and video generation with diffusion models, and now [supports](https://github.com/sgl-project/sglang/blob/main/python/sglang/multimodal_gen/runtime/pipelines/zimage_pipeline.py) Z-Image.\n- [Candle](https://github.com/huggingface/candle), a minimalist machine learning (ML) framework for Rust from Hugging Face, now [supports](https://github.com/huggingface/candle/pull/3261) Z-Image.\n- [MeanCache](https://github.com/UnicomAI/MeanCache) is a training-free inference acceleration method for Flow Matching models, developed by the China Unicom Data Science and Artificial Intelligence Research Institute. 
It delivers up to **3.7x** speedup for **Z-Image** generation with plug-and-play integration while preserving output quality.\n\n## 🚀 Star History\n\n[![Star History Chart](https://api.star-history.com/svg?repos=Tongyi-MAI/Z-Image&type=date&legend=top-left)](https://www.star-history.com/#Tongyi-MAI/Z-Image&type=date&legend=top-left)\n\n\n## 📜 Citation\n\nIf you find our work useful in your research, please consider citing:\n\n```bibtex\n@article{team2025zimage,\n  title={Z-Image: An Efficient Image Generation Foundation Model with Single-Stream Diffusion Transformer},\n  author={Z-Image Team},\n  journal={arXiv preprint arXiv:2511.22699},\n  year={2025}\n}\n\n@article{liu2025decoupled,\n  title={Decoupled DMD: CFG Augmentation as the Spear, Distribution Matching as the Shield},\n  author={Dongyang Liu and Peng Gao and David Liu and Ruoyi Du and Zhen Li and Qilong Wu and Xin Jin and Sihan Cao and Shifeng Zhang and Hongsheng Li and Steven Hoi},\n  journal={arXiv preprint arXiv:2511.22677},\n  year={2025}\n}\n\n@article{jiang2025distribution,\n  title={Distribution Matching Distillation Meets Reinforcement Learning},\n  author={Jiang, Dengyang and Liu, Dongyang and Wang, Zanyi and Wu, Qilong and Jin, Xin and Liu, David and Li, Zhen and Wang, Mengmeng and Gao, Peng and Yang, Harry},\n  journal={arXiv preprint arXiv:2511.13649},\n  year={2025}\n}\n\n```\n\n## 🤝 We're Hiring!\n\nWe're actively looking for **Research Scientists**, **Engineers**, and **Interns** to work on foundational generative models and their applications. Interested candidates, please send your resume to: **jingpeng.gp@alibaba-inc.com**\n"
  },
  {
    "path": "batch_inference.py",
    "content": "\"\"\"Batch prompt inference for Z-Image.\"\"\"\n\nimport os\nfrom pathlib import Path\nimport time\n\nimport torch\n\nfrom inference import ensure_weights\nfrom utils import AttentionBackend, load_from_local_dir, set_attention_backend\nfrom zimage import generate\n\n\ndef read_prompts(path: str) -> list[str]:\n    \"\"\"Read prompts from a text file (one per line, empty lines skipped).\"\"\"\n\n    prompt_path = Path(path)\n    if not prompt_path.exists():\n        raise FileNotFoundError(f\"Prompt file not found: {prompt_path}\")\n    with prompt_path.open(\"r\", encoding=\"utf-8\") as f:\n        prompts = [line.strip() for line in f if line.strip()]\n    if not prompts:\n        raise ValueError(f\"No prompts found in {prompt_path}\")\n    return prompts\n\n\nPROMPTS = read_prompts(os.environ.get(\"PROMPTS_FILE\", \"prompts/prompt1.txt\"))\n\n\ndef slugify(text: str, max_len: int = 60) -> str:\n    \"\"\"Create a filesystem-safe slug from the prompt.\"\"\"\n\n    slug = \"\".join(ch.lower() if ch.isalnum() else \"-\" for ch in text)\n    slug = \"-\".join(part for part in slug.split(\"-\") if part)\n    return slug[:max_len].rstrip(\"-\") or \"prompt\"\n\n\ndef select_device() -> str:\n    \"\"\"Choose the best available device without repeating detection logic.\"\"\"\n\n    if torch.cuda.is_available():\n        print(\"Chosen device: cuda\")\n        return \"cuda\"\n    try:\n        import torch_xla.core.xla_model as xm\n\n        device = xm.xla_device()\n        print(\"Chosen device: tpu\")\n        return device\n    except (ImportError, RuntimeError):\n        if torch.backends.mps.is_available():\n            print(\"Chosen device: mps\")\n            return \"mps\"\n        print(\"Chosen device: cpu\")\n        return \"cpu\"\n\n\ndef main():\n    model_path = ensure_weights(\"ckpts/Z-Image-Turbo\")\n    dtype = torch.bfloat16\n    compile = False\n    height = 1024\n    width = 1024\n    num_inference_steps = 8\n    guidance_scale = 
0.0\n    attn_backend = os.environ.get(\"ZIMAGE_ATTENTION\", \"_native_flash\")\n    output_dir = Path(\"outputs\")\n    output_dir.mkdir(exist_ok=True)\n\n    device = select_device()\n\n    components = load_from_local_dir(model_path, device=device, dtype=dtype, compile=compile)\n    AttentionBackend.print_available_backends()\n    set_attention_backend(attn_backend)\n    print(f\"Chosen attention backend: {attn_backend}\")\n\n    for idx, prompt in enumerate(PROMPTS, start=1):\n        output_path = output_dir / f\"prompt-{idx:02d}-{slugify(prompt)}.png\"\n        seed = 42 + idx - 1\n        generator = torch.Generator(device).manual_seed(seed)\n\n        start_time = time.time()\n        images = generate(\n            prompt=prompt,\n            **components,\n            height=height,\n            width=width,\n            num_inference_steps=num_inference_steps,\n            guidance_scale=guidance_scale,\n            generator=generator,\n        )\n        elapsed = time.time() - start_time\n        images[0].save(output_path)\n        print(f\"[{idx}/{len(PROMPTS)}] Saved {output_path} in {elapsed:.2f} seconds\")\n\n    print(\"Done.\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "inference.py",
    "content": "\"\"\"Z-Image PyTorch Native Inference.\"\"\"\n\nimport os\nimport time\nimport warnings\n\nimport torch\n\nwarnings.filterwarnings(\"ignore\")\nfrom utils import AttentionBackend, ensure_model_weights, load_from_local_dir, set_attention_backend\nfrom zimage import generate\n\n\ndef main():\n    model_path = ensure_model_weights(\"ckpts/Z-Image-Turbo\", verify=False)  # True to verify with md5\n    dtype = torch.bfloat16\n    compile = False  # default False for compatibility\n    output_path = \"example.png\"\n    height = 1024\n    width = 1024\n    num_inference_steps = 8\n    guidance_scale = 0.0\n    seed = 42\n    attn_backend = os.environ.get(\"ZIMAGE_ATTENTION\", \"_native_flash\")\n    prompt = (\n        \"Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. \"\n        \"Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. \"\n        \"Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. 
Soft-lit outdoor night background, \"\n        \"silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights.\"\n    )\n\n    # Device selection priority: cuda -> tpu -> mps -> cpu\n    if torch.cuda.is_available():\n        device = \"cuda\"\n        print(\"Chosen device: cuda\")\n    else:\n        try:\n            import torch_xla\n            import torch_xla.core.xla_model as xm\n\n            device = xm.xla_device()\n            print(\"Chosen device: tpu\")\n        except (ImportError, RuntimeError):\n            if torch.backends.mps.is_available():\n                device = \"mps\"\n                print(\"Chosen device: mps\")\n            else:\n                device = \"cpu\"\n                print(\"Chosen device: cpu\")\n    # Load models\n    components = load_from_local_dir(model_path, device=device, dtype=dtype, compile=compile)\n    AttentionBackend.print_available_backends()\n    set_attention_backend(attn_backend)\n    print(f\"Chosen attention backend: {attn_backend}\")\n\n    # Gen an image\n    start_time = time.time()\n    images = generate(\n        prompt=prompt,\n        **components,\n        height=height,\n        width=width,\n        num_inference_steps=num_inference_steps,\n        guidance_scale=guidance_scale,\n        generator=torch.Generator(device).manual_seed(seed),\n    )\n    end_time = time.time()\n    print(f\"Time taken: {end_time - start_time:.2f} seconds\")\n    images[0].save(output_path)\n\n    ### !! For best speed performance, recommend to use `_flash_3` backend and set `compile=True`\n    ### This would give you sub-second generation speed on Hopper GPU (H100/H200/H800) after warm-up\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "pyproject.toml",
    "content": "[build-system]\nrequires = [\"setuptools>=61.0\"]\nbuild-backend = \"setuptools.build_meta\"\n\n[project]\nname = \"zimage-native\"\nversion = \"0.1.0\"\ndescription = \"Z-Image PyTorch Native Implementation\"\nreadme = \"README.md\"\nrequires-python = \">=3.10\"\ndependencies = [\n    \"torch>=2.5.0\",\n    \"transformers>=4.51.0\",\n    \"safetensors\",\n    \"loguru\",\n    \"pillow\",\n    \"accelerate\",\n    \"huggingface_hub>=0.25.0\"\n]\n\n[project.optional-dependencies]\ndev = [\n    \"black\",\n    \"isort\",\n    \"ruff\"\n]\n\n[tool.setuptools.packages.find]\nwhere = [\"src\"]\n"
  },
  {
    "path": "src/__init__.py",
    "content": "\"\"\"Z-Image Native Implementation.\"\"\"\n\nfrom .utils import load_from_local_dir\nfrom .zimage import ZImageTransformer2DModel, generate\n\n__version__ = \"0.1.0\"\n\n__all__ = [\n    \"ZImageTransformer2DModel\",\n    \"generate\",\n    \"load_from_local_dir\",\n]\n"
  },
  {
    "path": "src/config/__init__.py",
    "content": "\"\"\"Z-Image Configuration.\"\"\"\n\nfrom .inference import (\n    DEFAULT_CFG_TRUNCATION,\n    DEFAULT_GUIDANCE_SCALE,\n    DEFAULT_HEIGHT,\n    DEFAULT_INFERENCE_STEPS,\n    DEFAULT_MAX_SEQUENCE_LENGTH,\n    DEFAULT_WIDTH,\n)\nfrom .model import (\n    ADALN_EMBED_DIM,\n    BASE_IMAGE_SEQ_LEN,\n    BASE_SHIFT,\n    BYTES_PER_GB,\n    DEFAULT_LOAD_DEVICE,\n    DEFAULT_LOAD_DTYPE_STR,\n    DEFAULT_SCHEDULER_NUM_TRAIN_TIMESTEPS,\n    DEFAULT_SCHEDULER_SHIFT,\n    DEFAULT_SCHEDULER_USE_DYNAMIC_SHIFTING,\n    DEFAULT_TRANSFORMER_CAP_FEAT_DIM,\n    DEFAULT_TRANSFORMER_DIM,\n    DEFAULT_TRANSFORMER_F_PATCH_SIZE,\n    DEFAULT_TRANSFORMER_IN_CHANNELS,\n    DEFAULT_TRANSFORMER_N_HEADS,\n    DEFAULT_TRANSFORMER_N_KV_HEADS,\n    DEFAULT_TRANSFORMER_N_LAYERS,\n    DEFAULT_TRANSFORMER_N_REFINER_LAYERS,\n    DEFAULT_TRANSFORMER_NORM_EPS,\n    DEFAULT_TRANSFORMER_PATCH_SIZE,\n    DEFAULT_TRANSFORMER_QK_NORM,\n    DEFAULT_TRANSFORMER_T_SCALE,\n    DEFAULT_VAE_IN_CHANNELS,\n    DEFAULT_VAE_LATENT_CHANNELS,\n    DEFAULT_VAE_NORM_NUM_GROUPS,\n    DEFAULT_VAE_OUT_CHANNELS,\n    DEFAULT_VAE_SCALE_FACTOR,\n    DEFAULT_VAE_SCALING_FACTOR,\n    FREQUENCY_EMBEDDING_SIZE,\n    MAX_IMAGE_SEQ_LEN,\n    MAX_PERIOD,\n    MAX_SHIFT,\n    ROPE_AXES_DIMS,\n    ROPE_AXES_LENS,\n    ROPE_THETA,\n    SEQ_MULTI_OF,\n)\n\n__all__ = [\n    \"ADALN_EMBED_DIM\",\n    \"SEQ_MULTI_OF\",\n    \"ROPE_THETA\",\n    \"ROPE_AXES_DIMS\",\n    \"ROPE_AXES_LENS\",\n    \"FREQUENCY_EMBEDDING_SIZE\",\n    \"MAX_PERIOD\",\n    \"BASE_IMAGE_SEQ_LEN\",\n    \"MAX_IMAGE_SEQ_LEN\",\n    \"BASE_SHIFT\",\n    \"MAX_SHIFT\",\n    \"DEFAULT_VAE_SCALE_FACTOR\",\n    \"DEFAULT_VAE_IN_CHANNELS\",\n    \"DEFAULT_VAE_OUT_CHANNELS\",\n    \"DEFAULT_VAE_LATENT_CHANNELS\",\n    \"DEFAULT_VAE_NORM_NUM_GROUPS\",\n    \"DEFAULT_VAE_SCALING_FACTOR\",\n    \"DEFAULT_TRANSFORMER_PATCH_SIZE\",\n    \"DEFAULT_TRANSFORMER_F_PATCH_SIZE\",\n    \"DEFAULT_TRANSFORMER_IN_CHANNELS\",\n    \"DEFAULT_TRANSFORMER_DIM\",\n    
\"DEFAULT_TRANSFORMER_N_LAYERS\",\n    \"DEFAULT_TRANSFORMER_N_REFINER_LAYERS\",\n    \"DEFAULT_TRANSFORMER_N_HEADS\",\n    \"DEFAULT_TRANSFORMER_N_KV_HEADS\",\n    \"DEFAULT_TRANSFORMER_NORM_EPS\",\n    \"DEFAULT_TRANSFORMER_QK_NORM\",\n    \"DEFAULT_TRANSFORMER_CAP_FEAT_DIM\",\n    \"DEFAULT_TRANSFORMER_T_SCALE\",\n    \"DEFAULT_SCHEDULER_NUM_TRAIN_TIMESTEPS\",\n    \"DEFAULT_SCHEDULER_SHIFT\",\n    \"DEFAULT_SCHEDULER_USE_DYNAMIC_SHIFTING\",\n    \"DEFAULT_LOAD_DEVICE\",\n    \"DEFAULT_LOAD_DTYPE_STR\",\n    \"BYTES_PER_GB\",\n    \"DEFAULT_HEIGHT\",\n    \"DEFAULT_WIDTH\",\n    \"DEFAULT_INFERENCE_STEPS\",\n    \"DEFAULT_GUIDANCE_SCALE\",\n    \"DEFAULT_CFG_TRUNCATION\",\n    \"DEFAULT_MAX_SEQUENCE_LENGTH\",\n]\n"
  },
  {
    "path": "src/config/inference.py",
    "content": "\"\"\"Inference-specific configuration for Z-Image.\"\"\"\n\nDEFAULT_HEIGHT = 1024\nDEFAULT_WIDTH = 1024\nDEFAULT_INFERENCE_STEPS = 8\nDEFAULT_GUIDANCE_SCALE = 0.0\nDEFAULT_CFG_TRUNCATION = 1.0\nDEFAULT_MAX_SEQUENCE_LENGTH = 512\n"
  },
  {
    "path": "src/config/manifests/README.md",
    "content": "# Model Manifests\n\nThis directory contains manifest files for different Z-Image model variants.\n\n## Purpose\n\nManifest files list all required files for each model, optionally with MD5 checksums for integrity verification.\n\n## File Naming Convention\n\n- `z-image-turbo.txt` - Z-Image Turbo model\n- Custom models: `{model-name}.txt`\n\n## Format\n\n### Standard Format (with MD5 - Recommended)\n\n```txt\n# Z-Image Model Manifest\n# Format: <md5hash>  <filepath>\n# Generated automatically - DO NOT edit manually\n\n5e3226ed72a9a4a080f2a4ca78b98ddc  model_index.json\nca682fcc6c5a94cf726b7187e64b9411  scheduler/scheduler_config.json\n1e97eb35d9d0b6aa60c58a8df8d7d99a  text_encoder/config.json\n30b85686b9a9b002e012494fadc027cb  text_encoder/model-00001-of-00003.safetensors\n...\n```\n\n**Verification Behavior:**\n- `verify=False`: Default, only checks file existence, ignores MD5 (fast)\n- `verify=True`: Checks existence AND verifies MD5 checksums (thorough)\n\n## Usage\n\nThe manifest file is automatically selected based on the model directory name:\n\n```python\n# Auto-detects manifest from \"Z-Image-Turbo\" -> uses z-image-turbo.txt\nmodel_path = ensure_model_weights(\"ckpts/Z-Image-Turbo\")\n\n# Explicit manifest\nmodel_path = ensure_model_weights(\"ckpts/Z-Image-Turbo\", manifest_name=\"z-image-turbo.txt\")\n```\n\n## Generating Manifests\n\nUse the provided tool to generate manifests:\n\n```bash\n# Generate with MD5 checksums (auto-saves to this directory)\npython -m src.tools.generate_manifest ckpts/Z-Image-Turbo\n\n# Generate without checksums (faster, not recommended)\npython -m src.tools.generate_manifest ckpts/Z-Image-Turbo --no-checksums\n\n# With verbose output\npython -m src.tools.generate_manifest ckpts/Z-Image-Turbo --verbose\n\n# Custom output path\npython -m src.tools.generate_manifest ckpts/Z-Image-Turbo --output custom.txt\n```\n\n## Available Manifests\n\n- **z-image-turbo.txt** - Z-Image Turbo model\n"
  },
  {
    "path": "src/config/manifests/z-image-turbo.txt",
    "content": "# Z-Image Model Manifest\n# Format: <md5hash>  <filepath>\n# Generated automatically - DO NOT edit manually\n\n5e3226ed72a9a4a080f2a4ca78b98ddc  model_index.json\nca682fcc6c5a94cf726b7187e64b9411  scheduler/scheduler_config.json\n1e97eb35d9d0b6aa60c58a8df8d7d99a  text_encoder/config.json\n30b85686b9a9b002e012494fadc027cb  text_encoder/model-00001-of-00003.safetensors\ne6a24ea164404a01ad2800dbae4e1a13  text_encoder/model-00002-of-00003.safetensors\n09e190ed15ff14795b6277e023cfcb2d  text_encoder/model-00003-of-00003.safetensors\n589f5395156900f49d617aee8a8d8708  text_encoder/model.safetensors.index.json\n6423133b9cc1a2077b57822c30c211aa  tokenizer/tokenizer.json\nb06e103ac555ec4b51266078b518c0f0  tokenizer/tokenizer_config.json\nbaed87136fe5f848e24b072f99856cc3  transformer/config.json\n54889d0dd179b4fa2fd7bd0e487d856e  transformer/diffusion_pytorch_model-00001-of-00003.safetensors\nfe81e804658d345323512c63224b0604  transformer/diffusion_pytorch_model-00002-of-00003.safetensors\n4e074e09129f98ad840414951f122feb  transformer/diffusion_pytorch_model-00003-of-00003.safetensors\n76d788eb0d42c59cc8f8ec007db639aa  transformer/diffusion_pytorch_model.safetensors.index.json\nba9e2980c8630b4abccc643bc9f4a542  vae/config.json\n6f83de55cb720c7fae051b14528577bf  vae/diffusion_pytorch_model.safetensors\n"
  },
  {
    "path": "src/config/model.py",
    "content": "\"\"\"Model configuration constants for Z-Image.\"\"\"\n\nADALN_EMBED_DIM = 256\nSEQ_MULTI_OF = 32\n\nROPE_THETA = 256.0\nROPE_AXES_DIMS = [32, 48, 48]\nROPE_AXES_LENS = [1536, 512, 512]\n\nFREQUENCY_EMBEDDING_SIZE = 256\nMAX_PERIOD = 10000\n\nBASE_IMAGE_SEQ_LEN = 256\nMAX_IMAGE_SEQ_LEN = 4096\nBASE_SHIFT = 0.5\nMAX_SHIFT = 1.15\n\nDEFAULT_VAE_SCALE_FACTOR = 8\nDEFAULT_VAE_IN_CHANNELS = 3\nDEFAULT_VAE_OUT_CHANNELS = 3\nDEFAULT_VAE_LATENT_CHANNELS = 4\nDEFAULT_VAE_NORM_NUM_GROUPS = 32\nDEFAULT_VAE_SCALING_FACTOR = 0.18215\n\nDEFAULT_TRANSFORMER_PATCH_SIZE = (2,)\nDEFAULT_TRANSFORMER_F_PATCH_SIZE = (1,)\nDEFAULT_TRANSFORMER_IN_CHANNELS = 16\nDEFAULT_TRANSFORMER_DIM = 3840\nDEFAULT_TRANSFORMER_N_LAYERS = 30\nDEFAULT_TRANSFORMER_N_REFINER_LAYERS = 2\nDEFAULT_TRANSFORMER_N_HEADS = 30\nDEFAULT_TRANSFORMER_N_KV_HEADS = 30\nDEFAULT_TRANSFORMER_NORM_EPS = 1e-5\nDEFAULT_TRANSFORMER_QK_NORM = True\nDEFAULT_TRANSFORMER_CAP_FEAT_DIM = 2560\nDEFAULT_TRANSFORMER_T_SCALE = 1000.0\n\nDEFAULT_SCHEDULER_NUM_TRAIN_TIMESTEPS = 1000\nDEFAULT_SCHEDULER_SHIFT = 3.0\nDEFAULT_SCHEDULER_USE_DYNAMIC_SHIFTING = False\n\nDEFAULT_LOAD_DEVICE = \"cuda\"\nDEFAULT_LOAD_DTYPE_STR = \"bfloat16\"\n\nBYTES_PER_GB = 2**30\n"
  },
  {
    "path": "src/tools/__init__.py",
    "content": "\"\"\"Tools for Z-Image model management.\"\"\"\n\nfrom .generate_manifest import compute_md5, get_essential_files\n\n__all__ = [\n    \"compute_md5\",\n    \"get_essential_files\",\n]\n\n"
  },
  {
    "path": "src/tools/generate_manifest.py",
    "content": "#!/usr/bin/env python3\n\"\"\"Generate manifest file with MD5 checksums for model weights.\n\nUsage:\n    python -m tools.generate_manifest ckpts/Z-Image-Turbo\n    python -m tools.generate_manifest ckpts/Z-Image-Turbo --no-checksums  # Only list files\n\"\"\"\n\nimport argparse\nimport hashlib\nfrom pathlib import Path\nfrom typing import List\n\n\ndef compute_md5(file_path: Path, chunk_size: int = 8192) -> str:\n    \"\"\"Compute MD5 hash of a file.\"\"\"\n    md5_hash = hashlib.md5()\n    with open(file_path, \"rb\") as f:\n        while chunk := f.read(chunk_size):\n            md5_hash.update(chunk)\n    return md5_hash.hexdigest()\n\n\ndef get_essential_files(model_dir: Path) -> List[Path]:\n    \"\"\"Get list of essential model files.\"\"\"\n    essential_patterns = [\n        \"model_index.json\",\n        \"transformer/config.json\",\n        \"transformer/*.safetensors*\",\n        \"vae/config.json\",\n        \"vae/*.safetensors\",\n        \"text_encoder/config.json\",\n        \"text_encoder/*.safetensors*\",\n        \"tokenizer/tokenizer.json\",\n        \"tokenizer/tokenizer_config.json\",\n        \"scheduler/scheduler_config.json\",\n    ]\n    \n    files = []\n    for pattern in essential_patterns:\n        if \"*\" in pattern:\n            files.extend(model_dir.glob(pattern))\n        else:\n            file_path = model_dir / pattern\n            if file_path.exists():\n                files.append(file_path)\n    \n    return sorted(files)\n\n\ndef main():\n    parser = argparse.ArgumentParser(description=\"Generate manifest file for model weights\")\n    parser.add_argument(\"model_dir\", type=str, help=\"Path to model directory\")\n    parser.add_argument(\"--output\", \"-o\", type=str, default=None,\n                       help=\"Output manifest file path (default: auto-detect to config/manifests/)\")\n    parser.add_argument(\"--no-checksums\", action=\"store_true\",\n                       help=\"Only list files without 
computing checksums\")\n    parser.add_argument(\"--verbose\", \"-v\", action=\"store_true\",\n                       help=\"Print progress\")\n    \n    args = parser.parse_args()\n    \n    model_dir = Path(args.model_dir)\n    if not model_dir.exists():\n        print(f\"Error: Model directory not found: {model_dir}\")\n        return 1\n    \n    # Determine output path\n    if args.output:\n        output_file = Path(args.output)\n    else:\n        # Auto-detect: save to config/manifests/{model-name}.txt\n        model_name = model_dir.name.lower()  # e.g., \"Z-Image-Turbo\" -> \"z-image-turbo\"\n        script_dir = Path(__file__).parent\n        config_dir = script_dir.parent / \"config\" / \"manifests\"\n        config_dir.mkdir(parents=True, exist_ok=True)\n        output_file = config_dir / f\"{model_name}.txt\"\n    \n    # Get essential files\n    files = get_essential_files(model_dir)\n    \n    if not files:\n        print(f\"Warning: No essential files found in {model_dir}\")\n        return 1\n    \n    print(f\"Found {len(files)} essential files\")\n    \n    # Generate manifest\n    with open(output_file, \"w\", encoding=\"utf-8\") as f:\n        f.write(\"# Z-Image Model Manifest\\n\")\n        if args.no_checksums:\n            f.write(\"# Format: <filepath>\\n\")\n        else:\n            f.write(\"# Format: <md5hash>  <filepath>\\n\")\n        f.write(\"# Generated automatically - DO NOT edit manually\\n\\n\")\n        \n        for file_path in files:\n            rel_path = file_path.relative_to(model_dir)\n            \n            if args.no_checksums:\n                f.write(f\"{rel_path}\\n\")\n                if args.verbose:\n                    print(f\"  {rel_path}\")\n            else:\n                if args.verbose:\n                    print(f\"Computing MD5 for {rel_path}...\", end=\" \", flush=True)\n                \n                try:\n                    md5_hash = compute_md5(file_path)\n                    
f.write(f\"{md5_hash}  {rel_path}\\n\")\n                    if args.verbose:\n                        print(f\"✓ {md5_hash}\")\n                except Exception as e:\n                    print(f\"✗ Error: {e}\")\n                    continue\n    \n    print(f\"\\n✓ Manifest saved to: {output_file}\")\n    print(f\"  Total files: {len(files)}\")\n    if not args.no_checksums:\n        print(f\"  With MD5 checksums for integrity verification\")\n    \n    return 0\n\n\nif __name__ == \"__main__\":\n    exit(main())\n\n"
  },
  {
    "path": "src/utils/__init__.py",
    "content": "\"\"\"Utilities for Z-Image.\"\"\"\n\nfrom .attention import AttentionBackend, dispatch_attention, set_attention_backend\nfrom .helpers import format_bytes, print_memory_stats, ensure_model_weights\nfrom .loader import load_from_local_dir\n\n__all__ = [\n    \"load_from_local_dir\",\n    \"format_bytes\",\n    \"print_memory_stats\",\n    \"ensure_model_weights\",\n    \"AttentionBackend\",\n    \"set_attention_backend\",\n    \"dispatch_attention\",\n]\n"
  },
  {
    "path": "src/utils/attention.py",
    "content": "\"\"\"Attention backend utilities for Z-Image.\"\"\"\n\n# Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_dispatch.py\nfrom enum import Enum\nimport functools\nimport inspect\nfrom typing import Callable, Dict, List, Optional, Union\n\nimport torch\nimport torch.nn.functional as F\n\nfrom .import_utils import is_flash_attn_3_available, is_flash_attn_available, is_torch_version\n\n_CAN_USE_FLASH_ATTN_2 = is_flash_attn_available()\n_CAN_USE_FLASH_ATTN_3 = is_flash_attn_3_available()\n\n# MPS Flash Attention (Apple Silicon)\ntry:\n    import mps_flash_attn\n    _CAN_USE_MPS_FLASH = mps_flash_attn.is_available()\nexcept ImportError:\n    _CAN_USE_MPS_FLASH = False\n    mps_flash_attn = None\n_TORCH_VERSION_CHECK = is_torch_version(\">=\", \"2.5.0\")  # have enable_gqa func call in SPDA\n\nif not _TORCH_VERSION_CHECK:\n    raise RuntimeError(\"PyTorch version must be >= 2.5.0 to use this backend.\")\nelse:\n    print(\"PyTorch version is >= 2.5.0, check pass.\")\n\nif _CAN_USE_FLASH_ATTN_2:\n    from flash_attn import flash_attn_func, flash_attn_varlen_func\nelse:\n    flash_attn_func = None\n    flash_attn_varlen_func = None\n\nif _CAN_USE_FLASH_ATTN_3:\n    from flash_attn_interface import (\n        flash_attn_func as flash_attn_3_func,\n        flash_attn_varlen_func as flash_attn_3_varlen_func,\n    )\n\n    _flash_attn_3_sig = inspect.signature(flash_attn_3_func)\n    _FLASH_ATTN_3_SUPPORTS_RETURN_PROBS = \"return_attn_probs\" in _flash_attn_3_sig.parameters\nelse:\n    flash_attn_3_func = None\n    flash_attn_3_varlen_func = None\n    _FLASH_ATTN_3_SUPPORTS_RETURN_PROBS = False\n\n\nclass AttentionBackend(str, Enum):\n    \"\"\"Supported attention backends.\"\"\"\n\n    # Flash Attention\n    FLASH = \"flash\"\n    FLASH_VARLEN = \"flash_varlen\"\n    FLASH_3 = \"_flash_3\"\n    FLASH_VARLEN_3 = \"_flash_varlen_3\"\n    # MPS Flash Attention (Apple Silicon)\n    MPS_FLASH = \"mps_flash\"\n    # 
PyTorch Native Backends\n    NATIVE = \"native\"\n    NATIVE_FLASH = \"_native_flash\"\n    NATIVE_MATH = \"_native_math\"\n\n    @classmethod\n    def print_available_backends(cls):\n        available_backends = [backend.value for backend in cls.__members__.values()]\n        print(f\"Available attention backends list: {available_backends}\")\n\n\n# Registry for attention implementations\n_ATTENTION_BACKENDS: Dict[str, Callable] = {}\n_ATTENTION_CONSTRAINTS: Dict[str, List[Callable]] = {}\n\n\ndef register_backend(name: str, constraints: Optional[List[Callable]] = None):\n    def decorator(func):\n        _ATTENTION_BACKENDS[name] = func\n        _ATTENTION_CONSTRAINTS[name] = constraints or []\n        return func\n\n    return decorator\n\n\n# --- Checks ---\ndef _check_device_cuda(query: torch.Tensor, **kwargs) -> None:\n    if query.device.type != \"cuda\":\n        raise ValueError(\"Query must be on a CUDA device.\")\n\n\ndef _check_qkv_dtype_bf16_or_fp16(query: torch.Tensor, **kwargs) -> None:\n    if query.dtype not in (torch.bfloat16, torch.float16):\n        raise ValueError(\"Query must be either bfloat16 or float16.\")\n\n\ndef _check_device_mps(query: torch.Tensor, **kwargs) -> None:\n    if query.device.type != \"mps\":\n        raise ValueError(\"Query must be on MPS device.\")\n\n\ndef _process_mask(attn_mask: Optional[torch.Tensor], dtype: torch.dtype):\n    if attn_mask is None:\n        return None\n\n    if attn_mask.ndim == 2:\n        attn_mask = attn_mask[:, None, None, :]\n\n    # Convert bool mask to float additive mask\n    if attn_mask.dtype == torch.bool:\n        # NOTE: We skip checking for all-True mask (torch.all) to avoid graph breaks in torch.compile\n        new_mask = torch.zeros_like(attn_mask, dtype=dtype)\n        new_mask.masked_fill_(~attn_mask, float(\"-inf\"))\n        return new_mask\n\n    return attn_mask\n\n\ndef _normalize_attn_mask(attn_mask: torch.Tensor, batch_size: int, seq_len_k: int) -> torch.Tensor:\n    
\"\"\"Normalize an attention mask to shape [batch_size, seq_len_k] (bool).\"\"\"\n    if attn_mask.dtype != torch.bool:\n        # Try to convert float mask back to bool if possible, or assume it's float mask\n        # For varlen flash attn, we strictly need bool mask indicating valid tokens\n        if torch.is_floating_point(attn_mask):\n            return attn_mask > -1  # Assuming -inf is masked\n        # raise ValueError(f\"Attention mask must be of type bool, got {attn_mask.dtype}.\")\n\n    if attn_mask.ndim == 1:\n        attn_mask = attn_mask.unsqueeze(0).expand(batch_size, seq_len_k)\n    elif attn_mask.ndim == 2:\n        if attn_mask.size(0) not in [1, batch_size]:\n            attn_mask = attn_mask.expand(batch_size, seq_len_k)\n    elif attn_mask.ndim == 3:\n        attn_mask = attn_mask.any(dim=1)\n        attn_mask = attn_mask.expand(batch_size, seq_len_k)\n    elif attn_mask.ndim == 4:\n        attn_mask = attn_mask.expand(batch_size, -1, -1, seq_len_k)\n        attn_mask = attn_mask.any(dim=(1, 2))\n\n    if attn_mask.shape != (batch_size, seq_len_k):\n        # Fallback reshape\n        return attn_mask.view(batch_size, seq_len_k)\n\n    return attn_mask\n\n\n@functools.lru_cache(maxsize=128)\ndef _prepare_for_flash_attn_varlen_without_mask(\n    batch_size: int,\n    seq_len_q: int,\n    seq_len_kv: int,\n    device: Optional[torch.device] = None,\n):\n    # Optimized to avoid Inductor \"pointless_cumsum_replacement\" crash and remove graph breaks\n    seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device)\n    seqlens_k = torch.full((batch_size,), seq_len_kv, dtype=torch.int32, device=device)\n\n    cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=device) * seq_len_q\n    cu_seqlens_k = torch.arange(batch_size + 1, dtype=torch.int32, device=device) * seq_len_kv\n\n    return (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (seq_len_q, seq_len_kv)\n\n\ndef 
_prepare_for_flash_attn_varlen_with_mask(\n    batch_size: int,\n    seq_len_q: int,\n    attn_mask: torch.Tensor,\n    device: Optional[torch.device] = None,\n):\n    seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device)\n    seqlens_k = attn_mask.sum(dim=1, dtype=torch.int32)\n    # Use arange for Q to avoid Inductor crash\n    cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=device) * seq_len_q\n\n    cu_seqlens_k = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)\n    cu_seqlens_k[1:] = torch.cumsum(seqlens_k, dim=0)\n\n    max_seqlen_q = seq_len_q\n    max_seqlen_k = attn_mask.shape[1]  # not max().item(), static shape to avoid graph break\n\n    return (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)\n\n\ndef _prepare_for_flash_attn_varlen(\n    batch_size: int,\n    seq_len_q: int,\n    seq_len_kv: int,\n    attn_mask: Optional[torch.Tensor] = None,\n    device: Optional[torch.device] = None,\n) -> None:\n    if attn_mask is None:\n        return _prepare_for_flash_attn_varlen_without_mask(batch_size, seq_len_q, seq_len_kv, device)\n    return _prepare_for_flash_attn_varlen_with_mask(batch_size, seq_len_q, attn_mask, device)\n\n\n@register_backend(AttentionBackend.FLASH, constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16])\ndef _flash_attention(\n    query: torch.Tensor,\n    key: torch.Tensor,\n    value: torch.Tensor,\n    attn_mask: Optional[torch.Tensor] = None,\n    dropout_p: float = 0.0,\n    is_causal: bool = False,\n    scale: Optional[float] = None,\n) -> torch.Tensor:\n    if not _CAN_USE_FLASH_ATTN_2:\n        raise RuntimeError(\n            f\"Flash Attention backend '{AttentionBackend.FLASH}' is not usable because of missing package.\"\n        )\n\n    out = flash_attn_func(\n        q=query,\n        k=key,\n        v=value,\n        dropout_p=dropout_p,\n        softmax_scale=scale,\n        causal=is_causal,\n    )\n    return 
out\n\n\n@register_backend(AttentionBackend.FLASH_VARLEN, constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16])\ndef _flash_varlen_attention(\n    query: torch.Tensor,\n    key: torch.Tensor,\n    value: torch.Tensor,\n    attn_mask: Optional[torch.Tensor] = None,\n    dropout_p: float = 0.0,\n    is_causal: bool = False,\n    scale: Optional[float] = None,\n) -> torch.Tensor:\n    if not _CAN_USE_FLASH_ATTN_2:\n        raise RuntimeError(f\"Backend '{AttentionBackend.FLASH_VARLEN}' requires flash-attn.\")\n\n    batch_size, seq_len_q, _, _ = query.shape\n    _, seq_len_kv, _, _ = key.shape\n\n    if attn_mask is not None:\n        attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)\n\n    (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = _prepare_for_flash_attn_varlen(\n        batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device\n    )\n\n    query_packed = query.flatten(0, 1)\n\n    if attn_mask is not None:\n        key_valid = []\n        value_valid = []\n        for b in range(batch_size):\n            valid_len = seqlens_k[b]\n            key_valid.append(key[b, :valid_len])\n            value_valid.append(value[b, :valid_len])\n        key_packed = torch.cat(key_valid, dim=0)\n        value_packed = torch.cat(value_valid, dim=0)\n    else:\n        key_packed = key.flatten(0, 1)\n        value_packed = value.flatten(0, 1)\n\n    out = flash_attn_varlen_func(\n        q=query_packed,\n        k=key_packed,\n        v=value_packed,\n        cu_seqlens_q=cu_seqlens_q,\n        cu_seqlens_k=cu_seqlens_k,\n        max_seqlen_q=max_seqlen_q,\n        max_seqlen_k=max_seqlen_k,\n        dropout_p=dropout_p,\n        softmax_scale=scale,\n        causal=is_causal,\n    )\n    out = out.unflatten(0, (batch_size, -1))\n    return out\n\n\n@register_backend(AttentionBackend.FLASH_3, constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16])\ndef _flash_attention_3(\n    query: 
torch.Tensor,\n    key: torch.Tensor,\n    value: torch.Tensor,\n    attn_mask: Optional[torch.Tensor] = None,  # Unused in simple FA3 func\n    dropout_p: float = 0.0,\n    is_causal: bool = False,\n    scale: Optional[float] = None,\n) -> torch.Tensor:\n    if not _CAN_USE_FLASH_ATTN_3:\n        raise RuntimeError(f\"Backend '{AttentionBackend.FLASH_3}' requires Flash Attention 3 beta.\")\n\n    kwargs = {\n        \"q\": query,\n        \"k\": key,\n        \"v\": value,\n        \"softmax_scale\": scale,\n        \"causal\": is_causal,\n    }\n\n    if _FLASH_ATTN_3_SUPPORTS_RETURN_PROBS:\n        kwargs[\"return_attn_probs\"] = False\n\n    out = flash_attn_3_func(**kwargs)\n\n    if isinstance(out, tuple):\n        out = out[0]\n\n    return out\n\n\n@register_backend(AttentionBackend.FLASH_VARLEN_3, constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16])\ndef _flash_varlen_attention_3(\n    query: torch.Tensor,\n    key: torch.Tensor,\n    value: torch.Tensor,\n    attn_mask: Optional[torch.Tensor] = None,\n    dropout_p: float = 0.0,\n    is_causal: bool = False,\n    scale: Optional[float] = None,\n) -> torch.Tensor:\n    if not _CAN_USE_FLASH_ATTN_3:\n        raise RuntimeError(f\"Backend '{AttentionBackend.FLASH_VARLEN_3}' requires Flash Attention 3 beta.\")\n\n    batch_size, seq_len_q, _, _ = query.shape\n    _, seq_len_kv, _, _ = key.shape\n\n    if attn_mask is not None:\n        attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)\n\n    (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = _prepare_for_flash_attn_varlen(\n        batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device\n    )\n\n    query_packed = query.flatten(0, 1)\n\n    if attn_mask is not None:\n        key_valid = []\n        value_valid = []\n        for b in range(batch_size):\n            valid_len = seqlens_k[b]\n            key_valid.append(key[b, :valid_len])\n            value_valid.append(value[b, 
:valid_len])\n        key_packed = torch.cat(key_valid, dim=0)\n        value_packed = torch.cat(value_valid, dim=0)\n    else:\n        key_packed = key.flatten(0, 1)\n        value_packed = value.flatten(0, 1)\n\n    kwargs = {\n        \"q\": query_packed,\n        \"k\": key_packed,\n        \"v\": value_packed,\n        \"cu_seqlens_q\": cu_seqlens_q,\n        \"cu_seqlens_k\": cu_seqlens_k,\n        \"max_seqlen_q\": max_seqlen_q,\n        \"max_seqlen_k\": max_seqlen_k,\n        \"softmax_scale\": scale,\n        \"causal\": is_causal,\n    }\n\n    supports_return_probs = \"return_attn_probs\" in inspect.signature(flash_attn_3_varlen_func).parameters\n\n    if supports_return_probs:\n        kwargs[\"return_attn_probs\"] = False\n\n    out = flash_attn_3_varlen_func(**kwargs)\n\n    if isinstance(out, tuple):\n        out = out[0]\n\n    out = out.unflatten(0, (batch_size, -1))\n    return out\n\n\n@register_backend(AttentionBackend.MPS_FLASH, constraints=[_check_device_mps, _check_qkv_dtype_bf16_or_fp16])\ndef _mps_flash_attention(\n    query: torch.Tensor,\n    key: torch.Tensor,\n    value: torch.Tensor,\n    attn_mask: Optional[torch.Tensor] = None,\n    dropout_p: float = 0.0,\n    is_causal: bool = False,\n    scale: Optional[float] = None,\n) -> torch.Tensor:\n    \"\"\"MPS Flash Attention for Apple Silicon (M1/M2/M3/M4).\"\"\"\n    if not _CAN_USE_MPS_FLASH:\n        raise RuntimeError(\n            f\"MPS Flash Attention backend '{AttentionBackend.MPS_FLASH}' requires mps-flash-attn package. 
\"\n            \"Install with: pip install mps-flash-attn\"\n        )\n\n    # Convert from (B, S, H, D) to (B, H, S, D) for mps-flash-attn\n    query = query.transpose(1, 2)\n    key = key.transpose(1, 2)\n    value = value.transpose(1, 2)\n\n    # Convert mask to MFA format (bool, True = masked)\n    mfa_mask = None\n    if attn_mask is not None:\n        mfa_mask = mps_flash_attn.convert_mask(_process_mask(attn_mask, query.dtype))\n\n    out = mps_flash_attn.flash_attention(\n        query, key, value,\n        is_causal=is_causal,\n        scale=scale,\n        attn_mask=mfa_mask,\n    )\n\n    # Convert back to (B, S, H, D)\n    return out.transpose(1, 2).contiguous()\n\n\ndef _native_attention_wrapper(\n    query: torch.Tensor,\n    key: torch.Tensor,\n    value: torch.Tensor,\n    attn_mask: Optional[torch.Tensor] = None,\n    dropout_p: float = 0.0,\n    is_causal: bool = False,\n    scale: Optional[float] = None,\n    backend_kernel=None,\n) -> torch.Tensor:\n\n    query = query.transpose(1, 2)\n    key = key.transpose(1, 2)\n    value = value.transpose(1, 2)\n    attn_mask = _process_mask(attn_mask, query.dtype)\n\n    if backend_kernel is not None:\n        with torch.nn.attention.sdpa_kernel(backend_kernel):\n            out = F.scaled_dot_product_attention(\n                query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale\n            )\n    else:\n        out = F.scaled_dot_product_attention(\n            query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale\n        )\n\n    return out.transpose(1, 2).contiguous()\n\n\n@register_backend(AttentionBackend.NATIVE_FLASH)\ndef _native_flash_attention(\n    query: torch.Tensor,\n    key: torch.Tensor,\n    value: torch.Tensor,\n    attn_mask: Optional[torch.Tensor] = None,\n    dropout_p: float = 0.0,\n    is_causal: bool = False,\n    scale: Optional[float] = None,\n) -> torch.Tensor:\n    return 
_native_attention_wrapper(\n        query,\n        key,\n        value,\n        attn_mask=None,\n        dropout_p=dropout_p,\n        is_causal=is_causal,\n        scale=scale,\n        backend_kernel=torch.nn.attention.SDPBackend.FLASH_ATTENTION,\n    )\n\n\n@register_backend(AttentionBackend.NATIVE_MATH)\ndef _math_attention(*args, **kwargs):\n    return _native_attention_wrapper(*args, **kwargs, backend_kernel=torch.nn.attention.SDPBackend.MATH)\n\n\n@register_backend(AttentionBackend.NATIVE)\ndef _native_attention(*args, **kwargs):\n    return _native_attention_wrapper(*args, **kwargs, backend_kernel=None)\n\n\ndef dispatch_attention(\n    query: torch.Tensor,\n    key: torch.Tensor,\n    value: torch.Tensor,\n    attn_mask: Optional[torch.Tensor] = None,\n    dropout_p: float = 0.0,\n    is_causal: bool = False,\n    scale: Optional[float] = None,\n    backend: Union[str, AttentionBackend, None] = None,\n) -> torch.Tensor:\n\n    if isinstance(backend, AttentionBackend):\n        backend = backend.value\n    elif backend is None:\n        backend = AttentionBackend.NATIVE\n    else:\n        backend = str(backend)\n\n    # Explicit dispatch to avoid dynamo guard issues on global dict\n    if backend == AttentionBackend.FLASH:\n        return _flash_attention(query, key, value, attn_mask, dropout_p, is_causal, scale)\n    elif backend == AttentionBackend.FLASH_VARLEN:\n        return _flash_varlen_attention(query, key, value, attn_mask, dropout_p, is_causal, scale)\n    elif backend == AttentionBackend.FLASH_3:\n        return _flash_attention_3(query, key, value, attn_mask, dropout_p, is_causal, scale)\n    elif backend == AttentionBackend.FLASH_VARLEN_3:\n        return _flash_varlen_attention_3(query, key, value, attn_mask, dropout_p, is_causal, scale)\n    elif backend == AttentionBackend.MPS_FLASH:\n        return _mps_flash_attention(query, key, value, attn_mask, dropout_p, is_causal, scale)\n    elif backend == AttentionBackend.NATIVE_FLASH:\n        
return _native_flash_attention(query, key, value, attn_mask, dropout_p, is_causal, scale)\n    elif backend == AttentionBackend.NATIVE_MATH:\n        return _math_attention(query, key, value, attn_mask, dropout_p, is_causal, scale)\n    else:\n        return _native_attention(query, key, value, attn_mask, dropout_p, is_causal, scale)\n\n\ndef set_attention_backend(backend: Union[str, AttentionBackend, None]):\n    try:\n        from zimage.transformer import ZImageAttention\n\n        if backend is not None:\n            backend = str(backend)\n        ZImageAttention._attention_backend = backend\n    except ImportError:\n        pass\n"
  },
  {
    "path": "src/utils/helpers.py",
    "content": "\"\"\"Helper utilities for Z-Image.\"\"\"\n\nimport hashlib\nimport json\nfrom pathlib import Path\nfrom typing import Optional, List, Tuple, Dict\n\nfrom loguru import logger\nimport torch\n\nfrom config import BYTES_PER_GB\n\n\ndef format_bytes(size: float) -> str:\n    \"\"\"\n    Format bytes to GB string.\n\n    Args:\n        size: Size in bytes\n\n    Returns:\n        Formatted string in GB\n    \"\"\"\n    n = size / BYTES_PER_GB\n    return f\"{n:.2f} GB\"\n\n\ndef print_memory_stats(stage: str) -> None:\n    \"\"\"\n    Print CUDA memory statistics.\n\n    Args:\n        stage: Description of current stage\n    \"\"\"\n    if not torch.cuda.is_available():\n        logger.warning(\"CUDA not available, skipping memory stats\")\n        return\n\n    torch.cuda.synchronize()\n    allocated = torch.cuda.max_memory_allocated()\n    reserved = torch.cuda.max_memory_reserved()\n    current_allocated = torch.cuda.memory_allocated()\n    current_reserved = torch.cuda.memory_reserved()\n\n    logger.info(f\"[{stage}] Memory Stats:\")\n    logger.info(f\"  Current Allocated: {format_bytes(current_allocated)}\")\n    logger.info(f\"  Current Reserved:  {format_bytes(current_reserved)}\")\n    logger.info(f\"  Peak Allocated:    {format_bytes(allocated)}\")\n    logger.info(f\"  Peak Reserved:     {format_bytes(reserved)}\")\n\n\ndef compute_file_md5(file_path: Path, chunk_size: int = 8192) -> str:\n    \"\"\"Compute the MD5 hash of a file as a lowercase hex string, reading it in chunks.\"\"\"\n    md5_hash = hashlib.md5()\n    with open(file_path, \"rb\") as f:\n        while chunk := f.read(chunk_size):\n            md5_hash.update(chunk)\n    return md5_hash.hexdigest()\n\n\ndef load_manifest(manifest_file: Path) -> Dict[str, Optional[str]]:\n    \"\"\"\n    Load a manifest file.\n\n    Each non-empty line that is not a ``#`` comment is either ``<path>`` or a\n    path/MD5 pair in either order (``<md5> <path>`` or ``<path> <md5>``).\n\n    Returns:\n        Dict mapping file paths to lowercase MD5 hashes (None when no hash is given).\n        Empty if the manifest file does not exist.\n    \"\"\"\n    manifest = {}\n    if not manifest_file.exists():\n        return manifest\n\n    with open(manifest_file, \"r\", encoding=\"utf-8\") as f:\n        for line_num, line in enumerate(f, 1):\n            line = line.strip()\n            # Skip empty lines and comments\n            if not line or line.startswith(\"#\"):\n                continue\n\n            parts = line.split()\n\n            if len(parts) == 1:\n                # Only file path, no checksum\n                manifest[parts[0]] = None\n            elif len(parts) == 2:\n                # File path with checksum; a 32-char hex token in the first column is the hash\n                if len(parts[0]) == 32 and all(c in '0123456789abcdef' for c in parts[0].lower()):\n                    md5_hash, file_path = parts\n                else:\n                    file_path, md5_hash = parts\n                # hexdigest() is lowercase, so normalize the expected hash to match\n                manifest[file_path] = md5_hash.lower()\n            else:\n                logger.warning(f\"Invalid manifest format at line {line_num}: {line}\")\n\n    return manifest\n\n\ndef verify_file_integrity(\n    base_dir: Path,\n    manifest: Dict[str, Optional[str]],\n    verify_checksums: bool = True\n) -> Tuple[bool, List[str], List[str]]:\n    \"\"\"\n    Verify file integrity using a manifest.\n\n    Args:\n        base_dir: Base directory for relative file paths\n        manifest: Dictionary of relative paths to MD5 hashes (None if no hash provided)\n        verify_checksums: If True, verify MD5 checksums when available; if False, only check existence\n\n    Returns:\n        Tuple of (all_valid: bool, missing_files: List[str], corrupted_files: List[str])\n    \"\"\"\n    missing = []\n    corrupted = []\n\n    for rel_path, expected_md5 in manifest.items():\n        file_path = base_dir / rel_path\n\n        if not file_path.exists():\n            missing.append(rel_path)\n            continue\n\n        # Only verify checksum if requested AND hash is available\n        if verify_checksums and expected_md5 is not None:\n            try:\n                actual_md5 = compute_file_md5(file_path)\n            except Exception as e:\n                # An unreadable file is reported as corrupted instead of aborting the scan\n                logger.error(f\"Failed to compute checksum for {rel_path}: {e}\")\n                corrupted.append(rel_path)\n                continue\n            # hexdigest() is lowercase; compare case-insensitively in case the caller passed upper-case hashes\n            if actual_md5 != expected_md5.lower():\n                corrupted.append(rel_path)\n                logger.debug(f\"Checksum mismatch for {rel_path}: expected {expected_md5}, got {actual_md5}\")\n\n    all_valid = not missing and not corrupted\n    return all_valid, missing, corrupted\n\n\ndef ensure_model_weights(\n    model_path: str,\n    repo_id: str = \"Tongyi-MAI/Z-Image-Turbo\",\n    verify: bool = False,\n    manifest_name: Optional[str] = None\n) -> Path:\n    \"\"\"\n    Ensure model weights exist and optionally verify integrity.\n\n    When required files are missing (or, with ``verify``, fail their checksum), the\n    repository is downloaded from the Hugging Face Hub into ``model_path`` and\n    re-checked against the manifest.\n\n    Args:\n        model_path: Path to model directory\n        repo_id: HuggingFace repo ID for download\n        verify: If True, verify MD5 checksums; if False, only check existence\n        manifest_name: Manifest file name in src/config/manifests/ (auto-detect if None)\n\n    Returns:\n        Path to validated model directory\n\n    Raises:\n        RuntimeError: If the download fails.\n        FileNotFoundError: If files are still missing or corrupted after the download.\n    \"\"\"\n    from huggingface_hub import snapshot_download\n\n    target_dir = Path(model_path)\n    manifests_dir = Path(__file__).parent.parent / \"config\" / \"manifests\"\n\n    # Determine manifest path\n    if manifest_name:\n        # Explicitly specified manifest from config/manifests/\n        manifest_path = manifests_dir / manifest_name\n    else:\n        # Auto-detect from the directory name (e.g. \"Z-Image-Turbo\" -> \"z-image-turbo.txt\"),\n        # falling back to a manifest.txt inside the model directory\n        config_manifest = manifests_dir / f\"{target_dir.name.lower()}.txt\"\n        manifest_path = config_manifest if config_manifest.exists() else target_dir / \"manifest.txt\"\n\n    manifest = load_manifest(manifest_path)\n\n    if not manifest:\n        # load_manifest() returns {} both for a missing file and for one without entries\n        if manifest_path.exists():\n            logger.warning(f\"Manifest file has no entries: {manifest_path}\")\n        else:\n            logger.warning(f\"Manifest file not found: {manifest_path}\")\n        logger.warning(\"Skipping file verification (assuming model exists)\")\n        if target_dir.exists():\n            logger.info(f\"✓ Model directory exists: {target_dir}\")\n            return target_dir\n        logger.warning(f\"Model directory not found: {target_dir}\")\n        missing_files = [\"entire model directory\"]\n        corrupted_files = []\n    else:\n        # Count files with checksums\n        files_with_checksums = sum(1 for v in manifest.values() if v is not None)\n\n        if verify and files_with_checksums == 0:\n            logger.info(\"Verify requested but no checksums in manifest, only checking existence\")\n        elif verify and files_with_checksums > 0:\n            logger.info(f\"Verifying {files_with_checksums} file(s) with MD5 checksums...\")\n\n        # Verify files\n        all_valid, missing_files, corrupted_files = verify_file_integrity(\n            target_dir, manifest, verify_checksums=verify\n        )\n\n        if all_valid:\n            if verify and files_with_checksums > 0:\n                logger.success(f\"✓ All files verified with MD5 checksums in {target_dir}\")\n            else:\n                logger.info(f\"✓ All {len(manifest)} required files exist in {target_dir}\")\n            return target_dir\n\n    # Report missing and corrupted files\n    if missing_files:\n        logger.warning(f\"Missing {len(missing_files)} file(s):\")\n        for f in missing_files[:10]:\n            logger.warning(f\"  - {f}\")\n        if len(missing_files) > 10:\n            logger.warning(f\"  ... and {len(missing_files) - 10} more\")\n\n    if corrupted_files:\n        logger.error(f\"Corrupted {len(corrupted_files)} file(s) (checksum mismatch):\")\n        for f in corrupted_files[:10]:\n            logger.error(f\"  - {f}\")\n        if len(corrupted_files) > 10:\n            logger.error(f\"  ... and {len(corrupted_files) - 10} more\")\n\n    # Download model weights\n    logger.info(f\"\\nAttempting to download from {repo_id}...\")\n    try:\n        target_dir.mkdir(parents=True, exist_ok=True)\n        snapshot_download(\n            repo_id=repo_id,\n            local_dir=str(target_dir),\n            local_dir_use_symlinks=False,\n            resume_download=True,\n        )\n        logger.success(\"✓ Download completed\")\n    except Exception as e:\n        logger.error(f\"✗ Download failed: {e}\")\n        logger.info(\n            f\"\\nIf you are offline, please manually download from:\\n\"\n            f\"  https://huggingface.co/{repo_id}\\n\"\n            f\"and place in: {target_dir.absolute()}\"\n        )\n        raise RuntimeError(f\"Failed to download model weights: {e}\") from e\n\n    # Verify after download\n    if manifest:\n        all_valid, missing_after, corrupted_after = verify_file_integrity(\n            target_dir, manifest, verify_checksums=verify\n        )\n\n        if not all_valid:\n            error_msg = []\n            if missing_after:\n                error_msg.append(f\"Still missing {len(missing_after)} file(s)\")\n            if corrupted_after:\n                error_msg.append(f\"Still corrupted {len(corrupted_after)} file(s)\")\n\n            raise FileNotFoundError(\n                f\"After download: {', '.join(error_msg)}\\n\"\n                f\"Please verify the download or manually place files in:\\n\"\n                f\"  {target_dir.absolute()}\"\n            )\n\n    logger.success(\"✓ All model weights validated successfully\")\n    return target_dir\n"
  },
  {
    "path": "src/utils/import_utils.py",
    "content": "import importlib.util\nimport operator as _operator\n\nimport torch\n\n# Comparison functions accepted by is_torch_version().\n_VERSION_COMPARATORS = {\n    \">\": _operator.gt,\n    \">=\": _operator.ge,\n    \"==\": _operator.eq,\n    \"!=\": _operator.ne,\n    \"<=\": _operator.le,\n    \"<\": _operator.lt,\n}\n\n\ndef is_flash_attn_available():\n    \"\"\"Return True if the ``flash_attn`` package can be imported.\"\"\"\n    return importlib.util.find_spec(\"flash_attn\") is not None\n\n\ndef is_flash_attn_3_available():\n    \"\"\"Return True if the ``flash_attn_interface`` module (Flash Attention 3) can be imported.\"\"\"\n    return importlib.util.find_spec(\"flash_attn_interface\") is not None\n\n\ndef is_torch_version(operator: str, version: str):\n    \"\"\"\n    Compare the installed torch version against ``version``.\n\n    Local build tags and pre-release suffixes of the installed torch (e.g. ``+cu121``,\n    ``a0+git...``) are dropped via ``base_version`` before comparing, so ``2.1.0+cu121``\n    compares equal to ``2.1.0``.\n\n    Args:\n        operator: One of ``\">\"``, ``\">=\"``, ``\"==\"``, ``\"!=\"``, ``\"<=\"``, ``\"<\"``.\n        version: Version string to compare against, e.g. ``\"2.1.0\"``.\n\n    Returns:\n        The comparison result; False for an unrecognized operator.\n    \"\"\"\n    from packaging import version as pversion\n\n    compare = _VERSION_COMPARATORS.get(operator)\n    if compare is None:\n        return False\n\n    torch_version = pversion.parse(pversion.parse(torch.__version__).base_version)\n    target_version = pversion.parse(version)\n    return compare(torch_version, target_version)\n"
  },
  {
    "path": "src/utils/loader.py",
    "content": "\"\"\"Model loading utilities for Z-Image components.\"\"\"\n\nimport json\nimport os\nfrom pathlib import Path\nimport sys\nfrom typing import Optional, Union\n\nfrom loguru import logger\nfrom safetensors.torch import load_file\nimport torch\nfrom transformers import AutoModel, AutoTokenizer\n\nfrom config import (\n    DEFAULT_SCHEDULER_NUM_TRAIN_TIMESTEPS,\n    DEFAULT_SCHEDULER_SHIFT,\n    DEFAULT_SCHEDULER_USE_DYNAMIC_SHIFTING,\n    DEFAULT_TRANSFORMER_CAP_FEAT_DIM,\n    DEFAULT_TRANSFORMER_DIM,\n    DEFAULT_TRANSFORMER_F_PATCH_SIZE,\n    DEFAULT_TRANSFORMER_IN_CHANNELS,\n    DEFAULT_TRANSFORMER_N_HEADS,\n    DEFAULT_TRANSFORMER_N_KV_HEADS,\n    DEFAULT_TRANSFORMER_N_LAYERS,\n    DEFAULT_TRANSFORMER_N_REFINER_LAYERS,\n    DEFAULT_TRANSFORMER_NORM_EPS,\n    DEFAULT_TRANSFORMER_PATCH_SIZE,\n    DEFAULT_TRANSFORMER_QK_NORM,\n    DEFAULT_TRANSFORMER_T_SCALE,\n    DEFAULT_VAE_IN_CHANNELS,\n    DEFAULT_VAE_LATENT_CHANNELS,\n    DEFAULT_VAE_NORM_NUM_GROUPS,\n    DEFAULT_VAE_OUT_CHANNELS,\n    DEFAULT_VAE_SCALING_FACTOR,\n    ROPE_AXES_DIMS,\n    ROPE_AXES_LENS,\n    ROPE_THETA,\n)\nfrom zimage.autoencoder import AutoencoderKL as LocalAutoencoderKL\nfrom zimage.scheduler import FlowMatchEulerDiscreteScheduler\n\nDIFFUSERS_AVAILABLE = False\n\n\ndef load_config(config_path: str) -> dict:\n    \"\"\"Read a UTF-8 JSON config file and return its contents.\"\"\"\n    with open(config_path, \"r\", encoding=\"utf-8\") as f:\n        return json.load(f)\n\n\ndef load_sharded_safetensors(weight_dir: Path, device: str = \"cuda\", dtype: Optional[torch.dtype] = None) -> dict:\n    \"\"\"\n    Load safetensors weights from a directory into a single state dict.\n\n    If a ``*.safetensors.index.json`` file is present, every shard referenced by its\n    ``weight_map`` is loaded; otherwise the first ``*.safetensors`` file (in sorted\n    order) is loaded.\n\n    Args:\n        weight_dir: Directory containing the weights.\n        device: Device the tensors are loaded onto.\n        dtype: If given, tensors with a different dtype are cast to it.\n\n    Returns:\n        Dict mapping parameter names to tensors.\n\n    Raises:\n        FileNotFoundError: If neither an index file nor a safetensors file is found.\n    \"\"\"\n    weight_dir = Path(weight_dir)\n    # Sort glob results: their order is filesystem-dependent\n    index_files = sorted(weight_dir.glob(\"*.safetensors.index.json\"))\n\n    state_dict = {}\n    if index_files:\n        # Load sharded weights\n        with open(index_files[0], \"r\", encoding=\"utf-8\") as f:\n            index = json.load(f)\n        weight_map = index.get(\"weight_map\", {})\n        for shard_file in sorted(set(weight_map.values())):\n            shard_state = load_file(str(weight_dir / shard_file), device=str(device))\n            state_dict.update(shard_state)\n    else:\n        # Load single safetensors file\n        safetensors_files = sorted(weight_dir.glob(\"*.safetensors\"))\n        if not safetensors_files:\n            raise FileNotFoundError(f\"No safetensors files found in {weight_dir}\")\n        state_dict = load_file(str(safetensors_files[0]), device=str(device))\n\n    # Cast to target dtype if specified\n    if dtype is not None:\n        state_dict = {k: v.to(dtype) if v.dtype != dtype else v for k, v in state_dict.items()}\n\n    return state_dict\n\n\ndef _log_incompatible_keys(component: str, load_result) -> None:\n    \"\"\"Log keys that a non-strict ``load_state_dict`` call could not match.\"\"\"\n    missing = list(load_result.missing_keys)\n    unexpected = list(load_result.unexpected_keys)\n    if missing:\n        # These parameters were not overwritten by the checkpoint and keep their initial values\n        logger.warning(f\"{component}: {len(missing)} key(s) missing from checkpoint, e.g. {missing[:5]}\")\n    if unexpected:\n        logger.debug(f\"{component}: {len(unexpected)} unexpected key(s) in checkpoint ignored, e.g. {unexpected[:5]}\")\n\n\ndef load_from_local_dir(\n    model_dir: Union[str, Path],\n    device: str = \"cuda\",\n    dtype: torch.dtype = torch.bfloat16,\n    verbose: bool = False,\n    compile: bool = False,\n) -> dict:\n    \"\"\"\n    Load all Z-Image components from local directory.\n\n    Args:\n        model_dir: Path to model directory\n        device: Device to load models on\n        dtype: Data type for model weights\n        verbose: Whether to display loading logs\n        compile: Whether to compile transformer and vae with torch.compile\n\n    Returns:\n        Dictionary containing transformer, vae, text_encoder, tokenizer, and scheduler\n    \"\"\"\n    model_dir = Path(model_dir)\n\n    # Make the sibling Z-Image source tree importable; skip the insert when it is already\n    # on sys.path so repeated calls do not keep growing it\n    zimage_src = str(model_dir.parent.parent / \"Z-Image\" / \"src\")\n    if zimage_src not in sys.path:\n        sys.path.insert(0, zimage_src)\n    from zimage.transformer import ZImageTransformer2DModel\n\n    if verbose:\n        logger.info(f\"Loading Z-Image from: {model_dir}\")\n\n    # DiT\n    if verbose:\n        logger.info(\"Loading DiT...\")\n    transformer_dir = model_dir / \"transformer\"\n    config = load_config(str(transformer_dir / \"config.json\"))\n\n    # Build on the meta device (no storage allocated); real tensors are attached below\n    # by load_state_dict(..., assign=True)\n    with torch.device(\"meta\"):\n        transformer = ZImageTransformer2DModel(\n            all_patch_size=tuple(config.get(\"all_patch_size\", DEFAULT_TRANSFORMER_PATCH_SIZE)),\n            all_f_patch_size=tuple(config.get(\"all_f_patch_size\", DEFAULT_TRANSFORMER_F_PATCH_SIZE)),\n            in_channels=config.get(\"in_channels\", DEFAULT_TRANSFORMER_IN_CHANNELS),\n            dim=config.get(\"dim\", DEFAULT_TRANSFORMER_DIM),\n            n_layers=config.get(\"n_layers\", DEFAULT_TRANSFORMER_N_LAYERS),\n            n_refiner_layers=config.get(\"n_refiner_layers\", DEFAULT_TRANSFORMER_N_REFINER_LAYERS),\n            n_heads=config.get(\"n_heads\", DEFAULT_TRANSFORMER_N_HEADS),\n            n_kv_heads=config.get(\"n_kv_heads\", DEFAULT_TRANSFORMER_N_KV_HEADS),\n            norm_eps=config.get(\"norm_eps\", DEFAULT_TRANSFORMER_NORM_EPS),\n            qk_norm=config.get(\"qk_norm\", DEFAULT_TRANSFORMER_QK_NORM),\n            cap_feat_dim=config.get(\"cap_feat_dim\", DEFAULT_TRANSFORMER_CAP_FEAT_DIM),\n            rope_theta=config.get(\"rope_theta\", ROPE_THETA),\n            t_scale=config.get(\"t_scale\", DEFAULT_TRANSFORMER_T_SCALE),\n            axes_dims=config.get(\"axes_dims\", ROPE_AXES_DIMS),\n            axes_lens=config.get(\"axes_lens\", ROPE_AXES_LENS),\n        ).to(dtype)\n\n    # DiT (weights to CPU then move to GPU to optimize memory)\n    state_dict = load_sharded_safetensors(transformer_dir, device=\"cpu\", dtype=dtype)\n    load_result = transformer.load_state_dict(state_dict, strict=False, assign=True)\n    del state_dict\n    _log_incompatible_keys(\"DiT\", load_result)\n\n    if verbose:\n        logger.info(\"Moving DiT to GPU...\")\n    transformer = transformer.to(device)\n    if torch.cuda.is_available():\n        torch.cuda.empty_cache()\n    transformer.eval()\n\n    # VAE\n    if verbose:\n        logger.info(\"Loading VAE...\")\n    vae_dir = model_dir / \"vae\"\n    vae_config = load_config(str(vae_dir / \"config.json\"))\n\n    vae = LocalAutoencoderKL(\n        in_channels=vae_config.get(\"in_channels\", DEFAULT_VAE_IN_CHANNELS),\n        out_channels=vae_config.get(\"out_channels\", DEFAULT_VAE_OUT_CHANNELS),\n        down_block_types=tuple(vae_config.get(\"down_block_types\", (\"DownEncoderBlock2D\",))),\n        up_block_types=tuple(vae_config.get(\"up_block_types\", (\"UpDecoderBlock2D\",))),\n        block_out_channels=tuple(vae_config.get(\"block_out_channels\", (64,))),\n        layers_per_block=vae_config.get(\"layers_per_block\", 1),\n        latent_channels=vae_config.get(\"latent_channels\", DEFAULT_VAE_LATENT_CHANNELS),\n        norm_num_groups=vae_config.get(\"norm_num_groups\", DEFAULT_VAE_NORM_NUM_GROUPS),\n        scaling_factor=vae_config.get(\"scaling_factor\", DEFAULT_VAE_SCALING_FACTOR),\n        shift_factor=vae_config.get(\"shift_factor\", None),\n        use_quant_conv=vae_config.get(\"use_quant_conv\", True),\n        use_post_quant_conv=vae_config.get(\"use_post_quant_conv\", True),\n        mid_block_add_attention=vae_config.get(\"mid_block_add_attention\", True),\n    )\n\n    # VAE (fp32 for better precision)\n    vae_state_dict = load_sharded_safetensors(vae_dir, device=\"cpu\")\n    vae_load_result = vae.load_state_dict(vae_state_dict, strict=False)\n    del vae_state_dict\n    _log_incompatible_keys(\"VAE\", vae_load_result)\n    vae.to(device=device, dtype=torch.float32)\n    vae.eval()\n    if torch.cuda.is_available():\n        torch.cuda.empty_cache()\n\n    # Text Encoder\n    if verbose:\n        logger.info(\"Loading Text Encoder...\")\n    text_encoder_dir = model_dir / \"text_encoder\"\n    text_encoder = AutoModel.from_pretrained(\n        str(text_encoder_dir),\n        # torch_dtype=dtype, # some version use this\n        dtype=dtype,\n        trust_remote_code=True,\n    )\n    text_encoder.to(device)\n    text_encoder.eval()\n\n    # Tokenizer\n    os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n    if verbose:\n        logger.info(\"Loading Tokenizer...\")\n    tokenizer_dir = model_dir / \"tokenizer\"\n    tokenizer = AutoTokenizer.from_pretrained(\n        str(tokenizer_dir) if tokenizer_dir.exists() else str(text_encoder_dir),\n        trust_remote_code=True,\n    )\n\n    # Scheduler\n    if verbose:\n        logger.info(\"Loading Scheduler...\")\n    scheduler_dir = model_dir / \"scheduler\"\n    scheduler_config = load_config(str(scheduler_dir / \"scheduler_config.json\"))\n    scheduler = FlowMatchEulerDiscreteScheduler(\n        num_train_timesteps=scheduler_config.get(\"num_train_timesteps\", DEFAULT_SCHEDULER_NUM_TRAIN_TIMESTEPS),\n        shift=scheduler_config.get(\"shift\", DEFAULT_SCHEDULER_SHIFT),\n        use_dynamic_shifting=scheduler_config.get(\"use_dynamic_shifting\", DEFAULT_SCHEDULER_USE_DYNAMIC_SHIFTING),\n    )\n\n    if compile:\n        if verbose:\n            logger.info(\"Compiling DiT and VAE...\")\n        transformer = torch.compile(transformer)\n        vae = torch.compile(vae)\n\n    if verbose:\n        logger.success(\"All components loaded successfully\")\n\n    return {\n        \"transformer\": transformer,\n        \"vae\": vae,\n        \"text_encoder\": text_encoder,\n        \"tokenizer\": tokenizer,\n        \"scheduler\": scheduler,\n    }\n"
  },
  {
    "path": "src/zimage/__init__.py",
    "content": "\"\"\"Z-Image: a native PyTorch implementation.\n\nPublic API re-exported at package level: ``generate`` and ``ZImageTransformer2DModel``.\n\"\"\"\n\nfrom .pipeline import generate\nfrom .transformer import ZImageTransformer2DModel\n\n__all__ = [\"ZImageTransformer2DModel\", \"generate\"]\n"
  },
  {
    "path": "src/zimage/autoencoder.py",
    "content": "\"\"\"AutoencoderKL implementation compatible with diffusers weights.\"\"\"\n\n# Modified from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/autoencoder.py\nfrom dataclasses import dataclass\nfrom typing import Optional, Tuple\n\nimport torch\nimport torch.nn as nn\n\n\n@dataclass\nclass AutoencoderKLOutput:\n    sample: torch.Tensor\n\n\nclass AutoencoderConfig:\n    def __init__(self, **kwargs):\n        self.__dict__.update(kwargs)\n\n    def get(self, key, default=None):\n        return self.__dict__.get(key, default)\n\n    def __getattr__(self, name):\n        return self.__dict__.get(name)\n\n\ndef swish(x):\n    return x * torch.sigmoid(x)\n\n\nclass ResnetBlock2D(nn.Module):\n    def __init__(self, in_channels, out_channels=None, dropout=0.0, temb_channels=512, groups=32, eps=1e-6):\n        super().__init__()\n        out_channels = out_channels or in_channels\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n\n        self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)\n        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)\n        self.norm2 = nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True)\n        self.dropout = nn.Dropout(dropout)\n        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)\n\n        self.nonlinearity = swish\n\n        if self.in_channels != self.out_channels:\n            self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)\n        else:\n            self.conv_shortcut = None\n\n    def forward(self, input_tensor, temb=None):\n        hidden_states = input_tensor\n        hidden_states = self.norm1(hidden_states)\n        hidden_states = self.nonlinearity(hidden_states)\n        hidden_states = self.conv1(hidden_states)\n\n        hidden_states = 
self.norm2(hidden_states)\n        hidden_states = self.nonlinearity(hidden_states)\n        hidden_states = self.dropout(hidden_states)\n        hidden_states = self.conv2(hidden_states)\n\n        if self.conv_shortcut is not None:\n            input_tensor = self.conv_shortcut(input_tensor)\n\n        output_tensor = (input_tensor + hidden_states) / 1.0\n        return output_tensor\n\n\nclass Attention(nn.Module):\n    def __init__(self, in_channels, heads=1, dim_head=None, groups=32, eps=1e-6):\n        super().__init__()\n        self.heads = heads\n        self.in_channels = in_channels\n        self.group_norm = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)\n\n        self.to_q = nn.Linear(in_channels, in_channels)\n        self.to_k = nn.Linear(in_channels, in_channels)\n        self.to_v = nn.Linear(in_channels, in_channels)\n        self.to_out = nn.ModuleList([nn.Linear(in_channels, in_channels)])\n\n    def forward(self, hidden_states):\n        b, c, h, w = hidden_states.shape\n        residual = hidden_states\n        hidden_states = self.group_norm(hidden_states)\n        hidden_states = hidden_states.view(b, c, -1).transpose(1, 2)  # (B, H*W, C)\n\n        query = self.to_q(hidden_states)\n        key = self.to_k(hidden_states)\n        value = self.to_v(hidden_states)\n\n        import torch.nn.functional as F\n\n        hidden_states = F.scaled_dot_product_attention(query, key, value)\n\n        hidden_states = self.to_out[0](hidden_states)\n        hidden_states = hidden_states.transpose(1, 2).view(b, c, h, w)\n\n        return residual + hidden_states\n\n\nclass Downsample2D(nn.Module):\n    def __init__(self, channels, with_conv=True, out_channels=None, padding=1):\n        super().__init__()\n        out_channels = out_channels or channels\n        self.with_conv = with_conv\n        if with_conv:\n            self.conv = nn.Conv2d(channels, out_channels, kernel_size=3, stride=2, padding=padding)\n\n    def 
forward(self, hidden_states):\n        if self.with_conv:\n            return self.conv(hidden_states)\n        else:\n            return torch.nn.functional.avg_pool2d(hidden_states, kernel_size=2, stride=2)\n\n\nclass Upsample2D(nn.Module):\n    def __init__(self, channels, with_conv=True, out_channels=None):\n        super().__init__()\n        out_channels = out_channels or channels\n        self.with_conv = with_conv\n        if with_conv:\n            self.conv = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)\n\n    def forward(self, hidden_states):\n        hidden_states = torch.nn.functional.interpolate(hidden_states, scale_factor=2.0, mode=\"nearest\")\n        if self.with_conv:\n            hidden_states = self.conv(hidden_states)\n        return hidden_states\n\n\nclass DownEncoderBlock2D(nn.Module):\n    def __init__(self, in_channels, out_channels, num_layers=1, resnet_eps=1e-6, resnet_groups=32, add_downsample=True):\n        super().__init__()\n        resnets = []\n        for i in range(num_layers):\n            in_c = in_channels if i == 0 else out_channels\n            resnets.append(ResnetBlock2D(in_c, out_channels, eps=resnet_eps, groups=resnet_groups))\n        self.resnets = nn.ModuleList(resnets)\n\n        if add_downsample:\n            self.downsamplers = nn.ModuleList(\n                [Downsample2D(out_channels, with_conv=True, out_channels=out_channels, padding=0)]\n            )\n        else:\n            self.downsamplers = None\n\n    def forward(self, hidden_states):\n        for resnet in self.resnets:\n            hidden_states = resnet(hidden_states)\n\n        if self.downsamplers is not None:\n            for downsampler in self.downsamplers:\n                pad = (0, 1, 0, 1)\n                hidden_states = torch.nn.functional.pad(hidden_states, pad, mode=\"constant\", value=0)\n                hidden_states = downsampler(hidden_states)\n\n        return hidden_states\n\n\nclass 
UpDecoderBlock2D(nn.Module):\n    def __init__(self, in_channels, out_channels, num_layers=1, resnet_eps=1e-6, resnet_groups=32, add_upsample=True):\n        super().__init__()\n        resnets = []\n        for i in range(num_layers):\n            in_c = in_channels if i == 0 else out_channels\n            resnets.append(ResnetBlock2D(in_c, out_channels, eps=resnet_eps, groups=resnet_groups))\n        self.resnets = nn.ModuleList(resnets)\n\n        if add_upsample:\n            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, with_conv=True, out_channels=out_channels)])\n        else:\n            self.upsamplers = None\n\n    def forward(self, hidden_states):\n        for resnet in self.resnets:\n            hidden_states = resnet(hidden_states)\n\n        if self.upsamplers is not None:\n            for upsampler in self.upsamplers:\n                hidden_states = upsampler(hidden_states)\n\n        return hidden_states\n\n\nclass UNetMidBlock2D(nn.Module):\n    def __init__(self, in_channels, resnet_eps=1e-6, resnet_groups=32, attention_head_dim=None):\n        super().__init__()\n        self.resnets = nn.ModuleList(\n            [\n                ResnetBlock2D(in_channels, in_channels, eps=resnet_eps, groups=resnet_groups),\n                ResnetBlock2D(in_channels, in_channels, eps=resnet_eps, groups=resnet_groups),\n            ]\n        )\n        self.attentions = nn.ModuleList([Attention(in_channels, heads=1, groups=resnet_groups, eps=resnet_eps)])\n\n    def forward(self, hidden_states):\n        hidden_states = self.resnets[0](hidden_states)\n        for attn in self.attentions:\n            hidden_states = attn(hidden_states)\n        hidden_states = self.resnets[1](hidden_states)\n        return hidden_states\n\n\nclass Encoder(nn.Module):\n    def __init__(\n        self,\n        in_channels=3,\n        out_channels=3,\n        block_out_channels=(64,),\n        layers_per_block=2,\n        norm_num_groups=32,\n        
double_z=True,\n    ):\n        super().__init__()\n        self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)\n\n        self.down_blocks = nn.ModuleList([])\n        output_channel = block_out_channels[0]\n        for i, block_out_channel in enumerate(block_out_channels):\n            input_channel = output_channel\n            output_channel = block_out_channel\n            is_final_block = i == len(block_out_channels) - 1\n\n            block = DownEncoderBlock2D(\n                input_channel,\n                output_channel,\n                num_layers=layers_per_block,\n                resnet_groups=norm_num_groups,\n                add_downsample=not is_final_block,\n            )\n            self.down_blocks.append(block)\n\n        self.mid_block = UNetMidBlock2D(\n            block_out_channels[-1],\n            resnet_groups=norm_num_groups,\n        )\n\n        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)\n        self.conv_act = nn.SiLU()\n\n        conv_out_channels = 2 * out_channels if double_z else out_channels\n        self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)\n\n    def forward(self, x):\n        x = self.conv_in(x)\n        for block in self.down_blocks:\n            x = block(x)\n        x = self.mid_block(x)\n        x = self.conv_norm_out(x)\n        x = self.conv_act(x)\n        x = self.conv_out(x)\n        return x\n\n\nclass Decoder(nn.Module):\n    def __init__(\n        self,\n        in_channels=3,\n        out_channels=3,\n        block_out_channels=(64,),\n        layers_per_block=2,\n        norm_num_groups=32,\n    ):\n        super().__init__()\n        self.conv_in = nn.Conv2d(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1)\n\n        self.mid_block = UNetMidBlock2D(\n            block_out_channels[-1],\n            resnet_groups=norm_num_groups,\n      
  )\n\n        self.up_blocks = nn.ModuleList([])\n        reversed_block_out_channels = list(reversed(block_out_channels))\n        output_channel = reversed_block_out_channels[0]\n\n        for i, block_out_channel in enumerate(reversed_block_out_channels):\n            input_channel = output_channel\n            output_channel = block_out_channel\n            is_final_block = i == len(block_out_channels) - 1\n            block = UpDecoderBlock2D(\n                input_channel,\n                output_channel,\n                num_layers=layers_per_block + 1,\n                resnet_groups=norm_num_groups,\n                add_upsample=not is_final_block,\n            )\n            self.up_blocks.append(block)\n\n        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)\n        self.conv_act = nn.SiLU()\n        self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)\n\n    def forward(self, x):\n        x = self.conv_in(x)\n        x = self.mid_block(x)\n        for block in self.up_blocks:\n            x = block(x)\n        x = self.conv_norm_out(x)\n        x = self.conv_act(x)\n        x = self.conv_out(x)\n        return x\n\n\nclass AutoencoderKL(nn.Module):\n    def __init__(\n        self,\n        in_channels: int = 3,\n        out_channels: int = 3,\n        down_block_types: Tuple[str] = (\"DownEncoderBlock2D\",),\n        up_block_types: Tuple[str] = (\"UpDecoderBlock2D\",),\n        block_out_channels: Tuple[int] = (64,),\n        layers_per_block: int = 1,\n        act_fn: str = \"silu\",\n        latent_channels: int = 4,\n        norm_num_groups: int = 32,\n        sample_size: int = 32,\n        scaling_factor: float = 0.18215,\n        shift_factor: Optional[float] = None,\n        force_upcast: bool = True,\n        use_quant_conv: bool = True,\n        use_post_quant_conv: bool = True,\n        mid_block_add_attention: bool = True,\n        
**kwargs,\n    ):\n        super().__init__()\n        self.config = AutoencoderConfig(\n            in_channels=in_channels,\n            out_channels=out_channels,\n            block_out_channels=block_out_channels,\n            layers_per_block=layers_per_block,\n            latent_channels=latent_channels,\n            scaling_factor=scaling_factor,\n            shift_factor=shift_factor,\n        )\n\n        self.encoder = Encoder(\n            in_channels=in_channels,\n            out_channels=latent_channels,\n            block_out_channels=block_out_channels,\n            layers_per_block=layers_per_block,\n            norm_num_groups=norm_num_groups,\n            double_z=True,\n        )\n\n        self.decoder = Decoder(\n            in_channels=latent_channels,\n            out_channels=out_channels,\n            block_out_channels=block_out_channels,\n            layers_per_block=layers_per_block,\n            norm_num_groups=norm_num_groups,\n        )\n\n        self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None\n        self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) if use_post_quant_conv else None\n\n    @property\n    def dtype(self):\n        return next(self.parameters()).dtype\n\n    def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:\n        if self.post_quant_conv is not None:\n            z = self.post_quant_conv(z)\n\n        dec = self.decoder(z)\n\n        if not return_dict:\n            return (dec,)\n\n        return AutoencoderKLOutput(sample=dec)\n"
  },
  {
    "path": "src/zimage/pipeline.py",
    "content": "\"\"\"Z-Image Pipeline.\"\"\"\n\nimport inspect\nfrom typing import List, Optional, Union\n\nfrom loguru import logger\nimport torch\n\nfrom config import (\n    BASE_IMAGE_SEQ_LEN,\n    BASE_SHIFT,\n    DEFAULT_CFG_TRUNCATION,\n    DEFAULT_GUIDANCE_SCALE,\n    DEFAULT_HEIGHT,\n    DEFAULT_INFERENCE_STEPS,\n    DEFAULT_MAX_SEQUENCE_LENGTH,\n    DEFAULT_WIDTH,\n    MAX_IMAGE_SEQ_LEN,\n    MAX_SHIFT,\n)\n\n\ndef calculate_shift(\n    image_seq_len,\n    base_seq_len: int = BASE_IMAGE_SEQ_LEN,\n    max_seq_len: int = MAX_IMAGE_SEQ_LEN,\n    base_shift: float = BASE_SHIFT,\n    max_shift: float = MAX_SHIFT,\n):\n    \"\"\"Linearly interpolate the timestep-shift parameter ``mu`` from the image sequence length.\"\"\"\n    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)\n    b = base_shift - m * base_seq_len\n    mu = image_seq_len * m + b\n    return mu\n\n\ndef retrieve_timesteps(\n    scheduler,\n    num_inference_steps: Optional[int] = None,\n    device: Optional[Union[str, torch.device]] = None,\n    timesteps: Optional[List[int]] = None,\n    sigmas: Optional[List[float]] = None,\n    **kwargs,\n):\n    \"\"\"Configure ``scheduler`` and return ``(timesteps, num_inference_steps)``.\n\n    At most one of ``timesteps`` / ``sigmas`` may be given. Extra ``kwargs``\n    (e.g. ``mu``) are forwarded to ``scheduler.set_timesteps``.\n    \"\"\"\n    if timesteps is not None and sigmas is not None:\n        raise ValueError(\"Only one of `timesteps` or `sigmas` can be passed.\")\n    if timesteps is not None:\n        accepts_timesteps = \"timesteps\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\n        if not accepts_timesteps:\n            raise ValueError(\"The scheduler does not support custom timestep schedules.\")\n        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    elif sigmas is not None:\n        accept_sigmas = \"sigmas\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\n        if not accept_sigmas:\n            raise ValueError(\"The scheduler does not support custom sigmas schedules.\")\n        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    else:\n        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n    return timesteps, num_inference_steps\n\n\n@torch.no_grad()\ndef generate(\n    transformer,\n    vae,\n    text_encoder,\n    tokenizer,\n    scheduler,\n    prompt: Union[str, List[str]],\n    height: int = DEFAULT_HEIGHT,\n    width: int = DEFAULT_WIDTH,\n    num_inference_steps: int = DEFAULT_INFERENCE_STEPS,\n    guidance_scale: float = DEFAULT_GUIDANCE_SCALE,\n    negative_prompt: Optional[Union[str, List[str]]] = None,\n    num_images_per_prompt: int = 1,\n    generator: Optional[torch.Generator] = None,\n    cfg_normalization: bool = False,\n    cfg_truncation: float = DEFAULT_CFG_TRUNCATION,\n    max_sequence_length: int = DEFAULT_MAX_SEQUENCE_LENGTH,\n    output_type: str = \"pil\",\n):\n    \"\"\"Sample images from text prompts with optional classifier-free guidance (CFG).\n\n    Returns the final latents when ``output_type == \"latent\"``, a list of PIL images\n    when ``output_type == \"pil\"``, and the decoded image tensor otherwise.\n\n    Raises:\n        ValueError: if ``height``/``width`` are not divisible by ``2 * vae_scale_factor``,\n            or if a list ``negative_prompt`` does not have one entry per prompt.\n    \"\"\"\n    device = next(transformer.parameters()).device\n\n    if hasattr(vae, \"config\") and hasattr(vae.config, \"block_out_channels\"):\n        vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)\n    else:\n        vae_scale_factor = 8\n    # The extra factor of 2 keeps the latent height/width even; they are halved again for image_seq_len below.\n    vae_scale = vae_scale_factor * 2\n\n    if height % vae_scale != 0:\n        raise ValueError(f\"Height must be divisible by {vae_scale} (got {height}).\")\n    if width % vae_scale != 0:\n        raise ValueError(f\"Width must be divisible by {vae_scale} (got {width}).\")\n\n    if isinstance(prompt, str):\n        batch_size = 1\n        prompt = [prompt]\n    else:\n        batch_size = len(prompt)\n\n    do_classifier_free_guidance = guidance_scale > 1.0\n    logger.info(f\"Generating image: {height}x{width}, steps={num_inference_steps}, cfg={guidance_scale}\")\n\n    formatted_prompts = []\n    for p in prompt:\n        messages = [{\"role\": \"user\", \"content\": p}]\n        formatted_prompt = tokenizer.apply_chat_template(\n            messages,\n            tokenize=False,\n            add_generation_prompt=True,\n            enable_thinking=True,\n        )\n        formatted_prompts.append(formatted_prompt)\n\n    text_inputs = tokenizer(\n        formatted_prompts,\n        padding=\"max_length\",\n        max_length=max_sequence_length,\n        truncation=True,\n        return_tensors=\"pt\",\n    )\n\n    text_input_ids = text_inputs.input_ids.to(device)\n    prompt_masks = text_inputs.attention_mask.to(device).bool()\n\n    prompt_embeds = text_encoder(\n        input_ids=text_input_ids,\n        attention_mask=prompt_masks,\n        output_hidden_states=True,\n    ).hidden_states[-2]\n\n    # Keep only the unmasked (non-padding) tokens of each prompt; lengths may differ.\n    prompt_embeds_list = []\n    for i in range(len(prompt_embeds)):\n        prompt_embeds_list.append(prompt_embeds[i][prompt_masks[i]])\n\n    negative_prompt_embeds_list = []\n    if do_classifier_free_guidance:\n        if negative_prompt is None:\n            negative_prompt = [\"\" for _ in prompt]\n        elif isinstance(negative_prompt, str):\n            # Broadcast a single negative prompt to every prompt so the positive and\n            # negative halves of the CFG batch line up one-to-one.\n            negative_prompt = [negative_prompt] * batch_size\n        elif len(negative_prompt) != batch_size:\n            raise ValueError(\n                f\"`negative_prompt` has {len(negative_prompt)} entries but `prompt` has {batch_size}.\"\n            )\n\n        neg_formatted = []\n        for p in negative_prompt:\n            messages = [{\"role\": \"user\", \"content\": p}]\n            formatted_prompt = tokenizer.apply_chat_template(\n                messages,\n                tokenize=False,\n                add_generation_prompt=True,\n                enable_thinking=True,\n            )\n            neg_formatted.append(formatted_prompt)\n\n        neg_inputs = tokenizer(\n            neg_formatted,\n            padding=\"max_length\",\n            max_length=max_sequence_length,\n            truncation=True,\n            return_tensors=\"pt\",\n        )\n\n        neg_input_ids = neg_inputs.input_ids.to(device)\n        neg_masks = neg_inputs.attention_mask.to(device).bool()\n\n        neg_embeds = text_encoder(\n            input_ids=neg_input_ids,\n            attention_mask=neg_masks,\n            output_hidden_states=True,\n        ).hidden_states[-2]\n\n        for i in range(len(neg_embeds)):\n            negative_prompt_embeds_list.append(neg_embeds[i][neg_masks[i]])\n\n    if num_images_per_prompt > 1:\n        prompt_embeds_list = [pe for pe in prompt_embeds_list for _ in range(num_images_per_prompt)]\n        if do_classifier_free_guidance:\n            negative_prompt_embeds_list = [\n                npe for npe in negative_prompt_embeds_list for _ in range(num_images_per_prompt)\n            ]\n\n    height_latent = 2 * (int(height) // vae_scale)\n    width_latent = 2 * (int(width) // vae_scale)\n    shape = (batch_size * num_images_per_prompt, transformer.in_channels, height_latent, width_latent)\n\n    latents = torch.randn(shape, generator=generator, device=device, dtype=torch.float32)\n\n    actual_batch_size = batch_size * num_images_per_prompt\n    image_seq_len = (latents.shape[2] // 2) * (latents.shape[3] // 2)\n\n    mu = calculate_shift(\n        image_seq_len,\n        scheduler.config.get(\"base_image_seq_len\", 256),\n        scheduler.config.get(\"max_image_seq_len\", 4096),\n        scheduler.config.get(\"base_shift\", 0.5),\n        scheduler.config.get(\"max_shift\", 1.15),\n    )\n    scheduler.sigma_min = 0.0\n    scheduler_kwargs = {\"mu\": mu}\n    timesteps, num_inference_steps = retrieve_timesteps(\n        scheduler,\n        num_inference_steps,\n        device,\n        sigmas=None,\n        **scheduler_kwargs,\n    )\n\n    logger.info(f\"Sampling loop start: {num_inference_steps} steps\")\n\n    from tqdm import tqdm\n\n    # Denoising loop with progress bar\n    for i, t in enumerate(tqdm(timesteps, desc=\"Denoising\", total=len(timesteps))):\n        # If current t is 0 and it's the last step, skip computation\n        if t == 0 and i == len(timesteps) - 1:\n            logger.debug(f\"Step {i+1}/{num_inference_steps} | t: {t.item():.2f} | Skipping last step\")\n            continue\n\n        timestep = t.expand(latents.shape[0])\n        timestep = (1000 - timestep) / 1000\n        t_norm = timestep[0].item()\n\n        current_guidance_scale = guidance_scale\n        if do_classifier_free_guidance and cfg_truncation is not None and float(cfg_truncation) <= 1:\n            if t_norm > cfg_truncation:\n                current_guidance_scale = 0.0\n\n        apply_cfg = do_classifier_free_guidance and current_guidance_scale > 0\n\n        if apply_cfg:\n            latents_typed = latents.to(\n                transformer.dtype if hasattr(transformer, \"dtype\") else next(transformer.parameters()).dtype\n            )\n            latent_model_input = latents_typed.repeat(2, 1, 1, 1)\n            prompt_embeds_model_input = prompt_embeds_list + negative_prompt_embeds_list\n            timestep_model_input = timestep.repeat(2)\n        else:\n            latent_model_input = latents.to(next(transformer.parameters()).dtype)\n            prompt_embeds_model_input = prompt_embeds_list\n            timestep_model_input = timestep\n\n        latent_model_input = latent_model_input.unsqueeze(2)\n        latent_model_input_list = list(latent_model_input.unbind(dim=0))\n\n        model_out_list = transformer(\n            latent_model_input_list,\n            timestep_model_input,\n            prompt_embeds_model_input,\n        )[0]\n\n        if apply_cfg:\n            pos_out = model_out_list[:actual_batch_size]\n            neg_out = model_out_list[actual_batch_size:]\n            noise_pred = []\n            for j in range(actual_batch_size):\n                pos = pos_out[j].float()\n                neg = neg_out[j].float()\n                pred = pos + current_guidance_scale * (pos - neg)\n\n                if cfg_normalization and float(cfg_normalization) > 0.0:\n                    ori_pos_norm = torch.linalg.vector_norm(pos)\n                    new_pos_norm = torch.linalg.vector_norm(pred)\n                    max_new_norm = ori_pos_norm * float(cfg_normalization)\n                    if new_pos_norm > max_new_norm:\n                        pred = pred * (max_new_norm / new_pos_norm)\n                noise_pred.append(pred)\n            noise_pred = torch.stack(noise_pred, dim=0)\n        else:\n            # Named `out` (not `t`) to avoid visually shadowing the loop timestep.\n            noise_pred = torch.stack([out.float() for out in model_out_list], dim=0)\n\n        noise_pred = -noise_pred.squeeze(2)\n        latents = scheduler.step(noise_pred.to(torch.float32), t, latents, return_dict=False)[0]\n        assert latents.dtype == torch.float32\n\n    if output_type == \"latent\":\n        return latents\n\n    shift_factor = getattr(vae.config, \"shift_factor\", 0.0) or 0.0\n    latents = (latents.to(vae.dtype) / vae.config.scaling_factor) + shift_factor\n    image = vae.decode(latents, return_dict=False)[0]\n\n    if output_type == \"pil\":\n        from PIL import Image\n\n        image = (image / 2 + 0.5).clamp(0, 1)\n        image = image.cpu().permute(0, 2, 3, 1).float().numpy()\n        image = (image * 255).round().astype(\"uint8\")\n        image = [Image.fromarray(img) for img in image]\n\n    return image\n"
  },
  {
    "path": "src/zimage/scheduler.py",
    "content": "\"\"\"FlowMatchEulerDiscreteScheduler implementation.\"\"\"\n\n# Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py\nfrom dataclasses import dataclass\nimport math\nfrom typing import List, Optional, Tuple, Union\n\nimport numpy as np\nimport torch\n\n\n@dataclass\nclass SchedulerOutput:\n    prev_sample: torch.FloatTensor\n\n\nclass SchedulerConfig:\n    def __init__(self, **kwargs):\n        self.__dict__.update(kwargs)\n\n    def get(self, key, default=None):\n        return self.__dict__.get(key, default)\n\n    def __getattr__(self, name):\n        return self.__dict__.get(name)\n\n\nclass FlowMatchEulerDiscreteScheduler:\n    \"\"\"Euler scheduler for flow matching.\"\"\"\n\n    def __init__(\n        self,\n        num_train_timesteps: int = 1000,\n        shift: float = 1.0,\n        use_dynamic_shifting: bool = False,\n        **kwargs,\n    ):\n        self.num_train_timesteps = num_train_timesteps\n        self.shift = shift\n        self.use_dynamic_shifting = use_dynamic_shifting\n        self.config = SchedulerConfig(\n            num_train_timesteps=num_train_timesteps,\n            shift=shift,\n            use_dynamic_shifting=use_dynamic_shifting,\n        )\n\n        timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()\n        timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)\n        sigmas = timesteps / num_train_timesteps\n\n        if not use_dynamic_shifting:\n            sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)\n\n        self.timesteps = sigmas * num_train_timesteps\n        self.sigmas = sigmas.to(\"cpu\")\n        self.sigma_min = self.sigmas[-1].item()\n        self.sigma_max = self.sigmas[0].item()\n\n        self._step_index = None\n        self._begin_index = None\n\n    def set_timesteps(\n        self,\n        num_inference_steps: Optional[int] = None,\n        device: Union[str, torch.device] = None,\n        sigmas: Optional[List[float]] = None,\n        mu: Optional[float] = None,\n        timesteps: Optional[List[float]] = None,\n    ):\n        \"\"\"Build the inference schedule (``self.timesteps`` and ``self.sigmas``).\n\n        ``sigmas`` / ``timesteps`` may be Python lists or NumPy arrays. ``mu`` is\n        required when ``use_dynamic_shifting`` is enabled. A trailing zero sigma is\n        appended so the final ``step`` lands on sigma 0.\n        \"\"\"\n        if self.use_dynamic_shifting and mu is None:\n            raise ValueError(\"`mu` must be passed when `use_dynamic_shifting` is True.\")\n\n        if timesteps is not None:\n            # Lists (as annotated) would break the division and torch.from_numpy below.\n            timesteps = np.asarray(timesteps, dtype=np.float32)\n        passed_timesteps = timesteps\n        if num_inference_steps is None:\n            num_inference_steps = len(sigmas) if sigmas is not None else len(timesteps)\n\n        self.num_inference_steps = num_inference_steps\n\n        if sigmas is None:\n            if timesteps is None:\n                timesteps = np.linspace(\n                    self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps + 1\n                )[:-1]\n            sigmas = timesteps / self.num_train_timesteps\n        else:\n            sigmas = np.array(sigmas).astype(np.float32)\n\n        if self.use_dynamic_shifting:\n            sigmas = self.time_shift(mu, 1.0, sigmas)\n        else:\n            sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas)\n\n        sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)\n\n        if passed_timesteps is None:\n            timesteps = sigmas * self.num_train_timesteps\n        else:\n            timesteps = torch.from_numpy(passed_timesteps).to(dtype=torch.float32, device=device)\n\n        sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])\n\n        self.timesteps = timesteps\n        self.sigmas = sigmas\n        self._step_index = None\n        self._begin_index = None\n\n    def index_for_timestep(self, timestep, schedule_timesteps=None):\n        if schedule_timesteps is None:\n            schedule_timesteps = self.timesteps\n\n        indices = (schedule_timesteps == timestep).nonzero()\n        pos = 1 if len(indices) > 1 else 0\n        return indices[pos].item()\n\n    def _init_step_index(self, timestep):\n        if self._begin_index is None:\n            if isinstance(timestep, torch.Tensor):\n                timestep = timestep.to(self.timesteps.device)\n            self._step_index = self.index_for_timestep(timestep)\n        else:\n            self._step_index = self._begin_index\n\n    def step(\n        self,\n        model_output: torch.FloatTensor,\n        timestep: Union[float, torch.FloatTensor],\n        sample: torch.FloatTensor,\n        return_dict: bool = True,\n        **kwargs,\n    ) -> Union[SchedulerOutput, Tuple]:\n        \"\"\"Predict the sample at the previous timestep.\n\n        One explicit Euler update: ``sample + (sigma_next - sigma) * model_output``.\n        \"\"\"\n        if self._step_index is None:\n            self._init_step_index(timestep)\n\n        sample = sample.to(torch.float32)\n        sigma_idx = self._step_index\n        sigma = self.sigmas[sigma_idx]\n        sigma_next = self.sigmas[sigma_idx + 1]\n\n        dt = sigma_next - sigma\n        prev_sample = sample + dt * model_output\n        self._step_index += 1\n        prev_sample = prev_sample.to(model_output.dtype)\n\n        if not return_dict:\n            return (prev_sample,)\n        return SchedulerOutput(prev_sample=prev_sample)\n\n    def _sigma_to_t(self, sigma):\n        return sigma * self.num_train_timesteps\n\n    def time_shift(self, mu: float, sigma: float, t: torch.Tensor):\n        return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)\n"
  },
  {
    "path": "src/zimage/transformer.py",
    "content": "\"\"\"Z-Image Transformer.\"\"\"\n\nimport math\nfrom typing import List, Optional, Tuple\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.nn.utils.rnn import pad_sequence\n\nfrom config import (\n    ADALN_EMBED_DIM,\n    FREQUENCY_EMBEDDING_SIZE,\n    MAX_PERIOD,\n    ROPE_AXES_DIMS,\n    ROPE_AXES_LENS,\n    ROPE_THETA,\n    SEQ_MULTI_OF,\n)\n\n\nclass TimestepEmbedder(nn.Module):\n    def __init__(self, out_size, mid_size=None, frequency_embedding_size=FREQUENCY_EMBEDDING_SIZE):\n        super().__init__()\n        if mid_size is None:\n            mid_size = out_size\n        self.mlp = nn.Sequential(\n            nn.Linear(frequency_embedding_size, mid_size, bias=True),\n            nn.SiLU(),\n            nn.Linear(mid_size, out_size, bias=True),\n        )\n        self.frequency_embedding_size = frequency_embedding_size\n\n    @staticmethod\n    def timestep_embedding(t, dim, max_period=MAX_PERIOD):\n        with torch.amp.autocast(\"cuda\", enabled=False):\n            half = dim // 2\n            freqs = torch.exp(\n                -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half\n            )\n            args = t[:, None].float() * freqs[None]\n            embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)\n            if dim % 2:\n                embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)\n            return embedding\n\n    def forward(self, t):\n        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)\n        weight_dtype = self.mlp[0].weight.dtype\n        if weight_dtype.is_floating_point:\n            t_freq = t_freq.to(weight_dtype)\n        t_emb = self.mlp(t_freq)\n        return t_emb\n\n\nclass RMSNorm(nn.Module):\n    def __init__(self, dim: int, eps: float = 1e-5):\n        super().__init__()\n        self.eps = eps\n        self.weight = 
nn.Parameter(torch.ones(dim))\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)\n        return output * self.weight\n\n\nclass FeedForward(nn.Module):\n    def __init__(self, dim: int, hidden_dim: int):\n        super().__init__()\n        self.w1 = nn.Linear(dim, hidden_dim, bias=False)\n        self.w2 = nn.Linear(hidden_dim, dim, bias=False)\n        self.w3 = nn.Linear(dim, hidden_dim, bias=False)\n\n    def forward(self, x):\n        return self.w2(F.silu(self.w1(x)) * self.w3(x))\n\n\ndef apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:\n    with torch.amp.autocast(\"cuda\", enabled=False):\n        x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))\n        freqs_cis = freqs_cis.unsqueeze(2)\n        x_out = torch.view_as_real(x * freqs_cis).flatten(3)\n        return x_out.type_as(x_in)\n\n\nclass ZImageAttention(nn.Module):\n    _attention_backend = None\n\n    def __init__(self, dim: int, n_heads: int, n_kv_heads: int, qk_norm: bool = True, eps: float = 1e-5):\n        super().__init__()\n        self.n_heads = n_heads\n        self.n_kv_heads = n_kv_heads\n        self.head_dim = dim // n_heads\n\n        self.to_q = nn.Linear(dim, n_heads * self.head_dim, bias=False)\n        self.to_k = nn.Linear(dim, n_kv_heads * self.head_dim, bias=False)\n        self.to_v = nn.Linear(dim, n_kv_heads * self.head_dim, bias=False)\n        self.to_out = nn.ModuleList([nn.Linear(n_heads * self.head_dim, dim, bias=False)])\n\n        self.norm_q = RMSNorm(self.head_dim, eps=eps) if qk_norm else None\n        self.norm_k = RMSNorm(self.head_dim, eps=eps) if qk_norm else None\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        freqs_cis: Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n        query = self.to_q(hidden_states)\n        key = 
self.to_k(hidden_states)\n        value = self.to_v(hidden_states)\n\n        query = query.unflatten(-1, (self.n_heads, -1))\n        key = key.unflatten(-1, (self.n_kv_heads, -1))\n        value = value.unflatten(-1, (self.n_kv_heads, -1))\n\n        if self.norm_q is not None:\n            query = self.norm_q(query)\n        if self.norm_k is not None:\n            key = self.norm_k(key)\n\n        if freqs_cis is not None:\n            query = apply_rotary_emb(query, freqs_cis)\n            key = apply_rotary_emb(key, freqs_cis)\n\n        dtype = query.dtype\n        query, key = query.to(dtype), key.to(dtype)\n\n        # Dispatch\n        from utils.attention import dispatch_attention\n\n        hidden_states = dispatch_attention(\n            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, backend=self._attention_backend\n        )\n\n        hidden_states = hidden_states.flatten(2, 3)\n        hidden_states = hidden_states.to(dtype)\n\n        output = self.to_out[0](hidden_states)\n        return output\n\n\nclass ZImageTransformerBlock(nn.Module):\n    def __init__(\n        self,\n        layer_id: int,\n        dim: int,\n        n_heads: int,\n        n_kv_heads: int,\n        norm_eps: float,\n        qk_norm: bool,\n        modulation=True,\n    ):\n        super().__init__()\n        self.dim = dim\n        self.head_dim = dim // n_heads\n        self.layer_id = layer_id\n        self.modulation = modulation\n\n        self.attention = ZImageAttention(dim, n_heads, n_kv_heads, qk_norm, norm_eps)\n        self.feed_forward = FeedForward(dim=dim, hidden_dim=int(dim / 3 * 8))\n\n        self.attention_norm1 = RMSNorm(dim, eps=norm_eps)\n        self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)\n        self.attention_norm2 = RMSNorm(dim, eps=norm_eps)\n        self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)\n\n        if modulation:\n            self.adaLN_modulation = nn.ModuleList([nn.Linear(min(dim, ADALN_EMBED_DIM), 4 * dim, 
bias=True)])\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        attn_mask: torch.Tensor,\n        freqs_cis: torch.Tensor,\n        adaln_input: Optional[torch.Tensor] = None,\n    ):\n        if self.modulation:\n            assert adaln_input is not None\n            scale_msa, gate_msa, scale_mlp, gate_mlp = (\n                self.adaLN_modulation[0](adaln_input).unsqueeze(1).chunk(4, dim=2)\n            )\n            gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()\n            scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp\n\n            attn_out = self.attention(\n                self.attention_norm1(x) * scale_msa,\n                attention_mask=attn_mask,\n                freqs_cis=freqs_cis,\n            )\n            x = x + gate_msa * self.attention_norm2(attn_out)\n            x = x + gate_mlp * self.ffn_norm2(self.feed_forward(self.ffn_norm1(x) * scale_mlp))\n        else:\n            attn_out = self.attention(\n                self.attention_norm1(x),\n                attention_mask=attn_mask,\n                freqs_cis=freqs_cis,\n            )\n            x = x + self.attention_norm2(attn_out)\n            x = x + self.ffn_norm2(self.feed_forward(self.ffn_norm1(x)))\n\n        return x\n\n\nclass FinalLayer(nn.Module):\n    def __init__(self, hidden_size, out_channels):\n        super().__init__()\n        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)\n        self.linear = nn.Linear(hidden_size, out_channels, bias=True)\n        self.adaLN_modulation = nn.Sequential(\n            nn.SiLU(),\n            nn.Linear(min(hidden_size, ADALN_EMBED_DIM), hidden_size, bias=True),\n        )\n\n    def forward(self, x, c):\n        scale = 1.0 + self.adaLN_modulation(c)\n        x = self.norm_final(x) * scale.unsqueeze(1)\n        x = self.linear(x)\n        return x\n\n\nclass RopeEmbedder:\n    def __init__(\n        self,\n        theta: float = ROPE_THETA,\n        axes_dims: 
List[int] = ROPE_AXES_DIMS,\n        axes_lens: List[int] = ROPE_AXES_LENS,\n    ):\n        self.theta = theta\n        self.axes_dims = axes_dims\n        self.axes_lens = axes_lens\n        assert len(axes_dims) == len(axes_lens)\n        self.freqs_cis = None\n\n    @staticmethod\n    def precompute_freqs_cis(dim: List[int], end: List[int], theta: float = ROPE_THETA):\n        with torch.device(\"cpu\"):\n            freqs_cis = []\n            for i, (d, e) in enumerate(zip(dim, end)):\n                freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device=\"cpu\") / d))\n                timestep = torch.arange(e, device=freqs.device, dtype=torch.float64)\n                freqs = torch.outer(timestep, freqs).float()\n                freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64)\n                freqs_cis.append(freqs_cis_i)\n            return freqs_cis\n\n    def __call__(self, ids: torch.Tensor):\n        assert ids.ndim == 2\n        assert ids.shape[-1] == len(self.axes_dims)\n        device = ids.device\n\n        if self.freqs_cis is None:\n            self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)\n            self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]\n        else:\n            if self.freqs_cis[0].device != device:\n                self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]\n\n        result = []\n        for i in range(len(self.axes_dims)):\n            index = ids[:, i]\n            result.append(self.freqs_cis[i][index])\n        return torch.cat(result, dim=-1)\n\n\nclass ZImageTransformer2DModel(nn.Module):\n    def __init__(\n        self,\n        all_patch_size=(2,),\n        all_f_patch_size=(1,),\n        in_channels=16,\n        dim=3840,\n        n_layers=30,\n        n_refiner_layers=2,\n        n_heads=30,\n        n_kv_heads=30,\n        norm_eps=1e-5,\n        qk_norm=True,\n        
cap_feat_dim=2560,\n        rope_theta=ROPE_THETA,\n        t_scale=1000.0,\n        axes_dims=ROPE_AXES_DIMS,\n        axes_lens=ROPE_AXES_LENS,\n    ):\n        super().__init__()\n        self.in_channels = in_channels\n        self.out_channels = in_channels\n        self.all_patch_size = all_patch_size\n        self.all_f_patch_size = all_f_patch_size\n        self.dim = dim\n        self.n_heads = n_heads\n        self.rope_theta = rope_theta\n        self.t_scale = t_scale\n\n        assert len(all_patch_size) == len(all_f_patch_size)\n\n        all_x_embedder = {}\n        all_final_layer = {}\n        for patch_size, f_patch_size in zip(all_patch_size, all_f_patch_size):\n            x_embedder = nn.Linear(f_patch_size * patch_size * patch_size * in_channels, dim, bias=True)\n            all_x_embedder[f\"{patch_size}-{f_patch_size}\"] = x_embedder\n            final_layer = FinalLayer(dim, patch_size * patch_size * f_patch_size * self.out_channels)\n            all_final_layer[f\"{patch_size}-{f_patch_size}\"] = final_layer\n\n        self.all_x_embedder = nn.ModuleDict(all_x_embedder)\n        self.all_final_layer = nn.ModuleDict(all_final_layer)\n\n        self.noise_refiner = nn.ModuleList(\n            [\n                ZImageTransformerBlock(1000 + layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation=True)\n                for layer_id in range(n_refiner_layers)\n            ]\n        )\n\n        self.context_refiner = nn.ModuleList(\n            [\n                ZImageTransformerBlock(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation=False)\n                for layer_id in range(n_refiner_layers)\n            ]\n        )\n\n        self.t_embedder = TimestepEmbedder(min(dim, ADALN_EMBED_DIM), mid_size=1024)\n        self.cap_embedder = nn.Sequential(\n            RMSNorm(cap_feat_dim, eps=norm_eps),\n            nn.Linear(cap_feat_dim, dim, bias=True),\n        )\n\n        self.x_pad_token = 
nn.Parameter(torch.empty((1, dim)))\n        self.cap_pad_token = nn.Parameter(torch.empty((1, dim)))\n\n        self.layers = nn.ModuleList(\n            [\n                ZImageTransformerBlock(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm)\n                for layer_id in range(n_layers)\n            ]\n        )\n\n        head_dim = dim // n_heads\n        assert head_dim == sum(axes_dims)\n        self.axes_dims = axes_dims\n        self.axes_lens = axes_lens\n\n        self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens)\n\n    def unpatchify(self, x: List[torch.Tensor], size: List[Tuple], patch_size, f_patch_size) -> List[torch.Tensor]:\n        pH = pW = patch_size\n        pF = f_patch_size\n        bsz = len(x)\n        assert len(size) == bsz\n        for i in range(bsz):\n            F, H, W = size[i]\n            ori_len = (F // pF) * (H // pH) * (W // pW)\n            x[i] = (\n                x[i][:ori_len]\n                .view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels)\n                .permute(6, 0, 3, 1, 4, 2, 5)\n                .reshape(self.out_channels, F, H, W)\n            )\n        return x\n\n    @staticmethod\n    def create_coordinate_grid(size, start=None, device=None):\n        if start is None:\n            start = (0 for _ in size)\n        axes = [torch.arange(x0, x0 + span, dtype=torch.int32, device=device) for x0, span in zip(start, size)]\n        grids = torch.meshgrid(axes, indexing=\"ij\")\n        return torch.stack(grids, dim=-1)\n\n    def patchify_and_embed(\n        self,\n        all_image: List[torch.Tensor],\n        all_cap_feats: List[torch.Tensor],\n        patch_size: int,\n        f_patch_size: int,\n    ):\n        pH = pW = patch_size\n        pF = f_patch_size\n        device = all_image[0].device\n\n        all_image_out = []\n        all_image_size = []\n        all_image_pos_ids = []\n        all_image_pad_mask = []\n        all_cap_pos_ids 
= []\n        all_cap_pad_mask = []\n        all_cap_feats_out = []\n\n        for _, (image, cap_feat) in enumerate(zip(all_image, all_cap_feats)):\n            cap_ori_len = len(cap_feat)\n            cap_padding_len = (-cap_ori_len) % SEQ_MULTI_OF\n            cap_padded_pos_ids = self.create_coordinate_grid(\n                size=(cap_ori_len + cap_padding_len, 1, 1),\n                start=(1, 0, 0),\n                device=device,\n            ).flatten(0, 2)\n            all_cap_pos_ids.append(cap_padded_pos_ids)\n            # pad mask\n            all_cap_pad_mask.append(\n                torch.cat(\n                    [\n                        torch.zeros((cap_ori_len,), dtype=torch.bool, device=device),\n                        torch.ones((cap_padding_len,), dtype=torch.bool, device=device),\n                    ],\n                    dim=0,\n                )\n                if cap_padding_len > 0\n                else torch.zeros((cap_ori_len,), dtype=torch.bool, device=device)\n            )\n            # padded feature\n            all_cap_feats_out.append(\n                torch.cat(\n                    [cap_feat, cap_feat[-1:].repeat(cap_padding_len, 1)],\n                    dim=0,\n                )\n                if cap_padding_len > 0\n                else cap_feat\n            )\n\n            C, F, H, W = image.size()\n            all_image_size.append((F, H, W))\n            F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW\n\n            image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)\n            image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C)\n\n            image_ori_len = len(image)\n            image_padding_len = (-image_ori_len) % SEQ_MULTI_OF\n\n            image_ori_pos_ids = self.create_coordinate_grid(\n                size=(F_tokens, H_tokens, W_tokens),\n                start=(cap_ori_len + cap_padding_len + 1, 0, 0),\n                
device=device,\n            ).flatten(0, 2)\n            image_padded_pos_ids = torch.cat(\n                [\n                    image_ori_pos_ids,\n                    self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device)\n                    .flatten(0, 2)\n                    .repeat(image_padding_len, 1),\n                ],\n                dim=0,\n            )\n            all_image_pos_ids.append(image_padded_pos_ids if image_padding_len > 0 else image_ori_pos_ids)\n            # pad mask\n            image_pad_mask = torch.cat(\n                [\n                    torch.zeros((image_ori_len,), dtype=torch.bool, device=device),\n                    torch.ones((image_padding_len,), dtype=torch.bool, device=device),\n                ],\n                dim=0,\n            )\n            all_image_pad_mask.append(\n                image_pad_mask\n                if image_padding_len > 0\n                else torch.zeros((image_ori_len,), dtype=torch.bool, device=device)\n            )\n            # padded feature\n            image_padded_feat = torch.cat(\n                [image, image[-1:].repeat(image_padding_len, 1)],\n                dim=0,\n            )\n            all_image_out.append(image_padded_feat if image_padding_len > 0 else image)\n\n        return (\n            all_image_out,\n            all_cap_feats_out,\n            all_image_size,\n            all_image_pos_ids,\n            all_cap_pos_ids,\n            all_image_pad_mask,\n            all_cap_pad_mask,\n        )\n\n    def forward(\n        self,\n        x: List[torch.Tensor],\n        t,\n        cap_feats: List[torch.Tensor],\n        patch_size=2,\n        f_patch_size=1,\n    ):\n        assert patch_size in self.all_patch_size\n        assert f_patch_size in self.all_f_patch_size\n\n        bsz = len(x)\n        device = x[0].device\n        t = t * self.t_scale\n        t = self.t_embedder(t)\n\n        (\n            x,\n            cap_feats,\n    
        x_size,\n            x_pos_ids,\n            cap_pos_ids,\n            x_inner_pad_mask,\n            cap_inner_pad_mask,\n        ) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size)\n\n        x_item_seqlens = [len(_) for _ in x]\n        assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens)\n        x_max_item_seqlen = max(x_item_seqlens)\n\n        x = torch.cat(x, dim=0)\n        x = self.all_x_embedder[f\"{patch_size}-{f_patch_size}\"](x)\n\n        adaln_input = t.type_as(x)\n        x[torch.cat(x_inner_pad_mask)] = self.x_pad_token\n        x = list(x.split(x_item_seqlens, dim=0))\n        x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split([len(_) for _ in x_pos_ids], dim=0))\n\n        x = pad_sequence(x, batch_first=True, padding_value=0.0)\n        x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0)\n        # Clarify the length matches to satisfy Dynamo due to \"Symbolic Shape Inference\" to avoid compilation errors\n        x_freqs_cis = x_freqs_cis[:, : x.shape[1]]\n\n        x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device)\n        for i, seq_len in enumerate(x_item_seqlens):\n            x_attn_mask[i, :seq_len] = 1\n\n        for layer in self.noise_refiner:\n            x = layer(x, x_attn_mask, x_freqs_cis, adaln_input)\n\n        cap_item_seqlens = [len(_) for _ in cap_feats]\n        assert all(_ % SEQ_MULTI_OF == 0 for _ in cap_item_seqlens)\n        cap_max_item_seqlen = max(cap_item_seqlens)\n\n        cap_feats = torch.cat(cap_feats, dim=0)\n        cap_feats = self.cap_embedder(cap_feats)\n        cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token\n        cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0))\n        cap_freqs_cis = list(\n            self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split([len(_) for _ in cap_pos_ids], dim=0)\n        )\n\n        cap_feats = pad_sequence(cap_feats, 
batch_first=True, padding_value=0.0)\n        cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0)\n        cap_freqs_cis = cap_freqs_cis[:, : cap_feats.shape[1]]  # same for dynamo compatibility\n\n        cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device)\n        for i, seq_len in enumerate(cap_item_seqlens):\n            cap_attn_mask[i, :seq_len] = 1\n\n        for layer in self.context_refiner:\n            cap_feats = layer(cap_feats, cap_attn_mask, cap_freqs_cis)\n\n        unified = []\n        unified_freqs_cis = []\n        for i in range(bsz):\n            x_len = x_item_seqlens[i]\n            cap_len = cap_item_seqlens[i]\n            unified.append(torch.cat([x[i][:x_len], cap_feats[i][:cap_len]]))\n            unified_freqs_cis.append(torch.cat([x_freqs_cis[i][:x_len], cap_freqs_cis[i][:cap_len]]))\n        unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)]\n        assert unified_item_seqlens == [len(_) for _ in unified]\n        unified_max_item_seqlen = max(unified_item_seqlens)\n\n        unified = pad_sequence(unified, batch_first=True, padding_value=0.0)\n        unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0)\n        unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device)\n        for i, seq_len in enumerate(unified_item_seqlens):\n            unified_attn_mask[i, :seq_len] = 1\n\n        for layer in self.layers:\n            unified = layer(unified, unified_attn_mask, unified_freqs_cis, adaln_input)\n\n        unified = self.all_final_layer[f\"{patch_size}-{f_patch_size}\"](unified, adaln_input)\n        unified = list(unified.unbind(dim=0))\n        x = self.unpatchify(unified, x_size, patch_size, f_patch_size)\n\n        return x, {}\n"
  }
]