Repository: Tongyi-MAI/Z-Image
Branch: main
Commit: 26f23eda626f
Files: 24
Total size: 138.3 KB

Directory structure:
gitextract_m78kzmq8/

├── .gitignore
├── LICENSE
├── README.md
├── batch_inference.py
├── inference.py
├── pyproject.toml
└── src/
    ├── __init__.py
    ├── config/
    │   ├── __init__.py
    │   ├── inference.py
    │   ├── manifests/
    │   │   ├── README.md
    │   │   └── z-image-turbo.txt
    │   └── model.py
    ├── tools/
    │   ├── __init__.py
    │   └── generate_manifest.py
    ├── utils/
    │   ├── __init__.py
    │   ├── attention.py
    │   ├── helpers.py
    │   ├── import_utils.py
    │   └── loader.py
    └── zimage/
        ├── __init__.py
        ├── autoencoder.py
        ├── pipeline.py
        ├── scheduler.py
        └── transformer.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[codz]
*$py.class

# C extensions
*.so
outputs/
prompts/
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py.cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
#   For a library or package, you might want to ignore these files since the code is
#   intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# UV
#   Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
#   This is especially recommended for binary packages to ensure reproducibility, and is more
#   commonly ignored for libraries.
#uv.lock

# poetry
#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
#   This is especially recommended for binary packages to ensure reproducibility, and is more
#   commonly ignored for libraries.
#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
#poetry.toml

# pdm
#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#   pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
#   https://pdm-project.org/en/latest/usage/project/#working-with-version-control
#pdm.lock
#pdm.toml
.pdm-python
.pdm-build/

# pixi
#   Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
#pixi.lock
#   Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
#   in the .venv directory. It is recommended not to include this directory in version control.
.pixi

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.envrc
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
#  and can be added to the global gitignore or merged into this file.  For a more nuclear
#  option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# Abstra
# Abstra is an AI-powered process automation framework.
# Ignore directories containing user credentials, local state, and settings.
# Learn more at https://abstra.io/docs
.abstra/

# Visual Studio Code
#  Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore 
#  that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
#  and can be added to the global gitignore or merged into this file. However, if you prefer, 
#  you could uncomment the following to ignore the entire vscode folder
# .vscode/

# Ruff stuff:
.ruff_cache/

# PyPI configuration file
.pypirc

# Cursor
#  Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
#  exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
#  refer to https://docs.cursor.com/context/ignore-files
.cursorignore
.cursorindexingignore

# Marimo
marimo/_static/
marimo/_lsp/
__marimo__/

# Z-Image
ckpts/
.isort.cfg
.pre-commit-config.yaml
*.DS_Store

# Ignore generated images
/*.png

================================================
FILE: LICENSE
================================================
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.


================================================
FILE: README.md
================================================
<h1 align="center">⚡️- Image<br><sub><sup>An Efficient Image Generation Foundation Model with Single-Stream Diffusion Transformer</sup></sub></h1>

<div align="center">

[![Official Site](https://img.shields.io/badge/Official%20Site-333399.svg?logo=homepage)](https://tongyi-mai.github.io/Z-Image-blog/)&#160;
[![Hugging Face](https://img.shields.io/badge/%F0%9F%A4%97%20Checkpoint-Z--Image-yellow)](https://huggingface.co/Tongyi-MAI/Z-Image)&#160;
[![Hugging Face](https://img.shields.io/badge/%F0%9F%A4%97%20Checkpoint-Z--Image--Turbo-yellow)](https://huggingface.co/Tongyi-MAI/Z-Image-Turbo)&#160;
[![Hugging Face](https://img.shields.io/badge/%F0%9F%A4%97%20Online_Demo-Z--Image-blue)](https://huggingface.co/spaces/Tongyi-MAI/Z-Image)&#160;
[![Hugging Face](https://img.shields.io/badge/%F0%9F%A4%97%20Online_Demo-Z--Image--Turbo-blue)](https://huggingface.co/spaces/Tongyi-MAI/Z-Image-Turbo)&#160;
[![ModelScope Model](https://img.shields.io/badge/🤖%20Checkpoint-Z--Image-624aff)](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image)&#160;
[![ModelScope Model](https://img.shields.io/badge/🤖%20Checkpoint-Z--Image--Turbo-624aff)](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo)&#160;
[![ModelScope Space](https://img.shields.io/badge/🤖%20Online_Demo-Z--Image-17c7a7)](https://www.modelscope.cn/aigc/imageGeneration?tab=advanced&versionId=569345&modelType=Checkpoint&sdVersion=Z_IMAGE&modelUrl=modelscope%3A%2F%2FTongyi-MAI%2FZ-Image%3Frevision%3Dmaster)&#160;
[![ModelScope Space](https://img.shields.io/badge/🤖%20Online_Demo-Z--Image--Turbo-17c7a7)](https://www.modelscope.cn/aigc/imageGeneration?tab=advanced&versionId=469191&modelType=Checkpoint&sdVersion=Z_IMAGE_TURBO&modelUrl=modelscope%3A%2F%2FTongyi-MAI%2FZ-Image-Turbo%3Frevision%3Dmaster)&#160;
[![Art Gallery PDF](https://img.shields.io/badge/%F0%9F%96%BC%20Art_Gallery-PDF-ff69b4)](assets/Z-Image-Gallery.pdf)&#160;
[![Web Art Gallery](https://img.shields.io/badge/%F0%9F%8C%90%20Web_Art_Gallery-online-00bfff)](https://modelscope.cn/studios/Tongyi-MAI/Z-Image-Gallery/summary)&#160;
<a href="https://arxiv.org/abs/2511.22699" target="_blank"><img src="https://img.shields.io/badge/Report-b5212f.svg?logo=arxiv" height="21px"></a>


Welcome to the official repository for the Z-Image (造相) project!

</div>



## ✨ Z-Image

Z-Image is a powerful and highly efficient image generation model family with **6B** parameters. Currently there are four variants:

- 🚀 **Z-Image-Turbo** – A distilled version of Z-Image that matches or exceeds leading competitors with only **8 NFEs** (Number of Function Evaluations). It offers **⚡️sub-second inference latency⚡️** on enterprise-grade H800 GPUs and fits comfortably on **consumer devices with 16 GB of VRAM**. It excels in photorealistic image generation, bilingual text rendering (English & Chinese), and robust instruction adherence.

- 🎨 **Z-Image** – The foundation model behind Z-Image-Turbo. Z-Image focuses on **high-quality generation**, **rich aesthetics**, **strong diversity**, and **controllability**, well-suited for creative generation, **fine-tuning**, and downstream development. It supports a wide range of artistic styles, effective negative prompting, and high diversity across identities, poses, compositions, and layouts.

- 🧱 **Z-Image-Omni-Base** – The versatile foundation model capable of both **generation and editing tasks**. By releasing this checkpoint, we aim to unlock the full potential for community-driven fine-tuning and custom development, providing the most "raw" and diverse starting point for the open-source community.

- ✍️ **Z-Image-Edit** – A variant fine-tuned on Z-Image specifically for image editing tasks. It supports creative image-to-image generation with impressive instruction-following capabilities, allowing for precise edits based on natural language prompts.

### 📣 News

*   **[2026-01-27]** 🔥 **Z-Image is released!** We have released the model checkpoint on [Hugging Face](https://huggingface.co/Tongyi-MAI/Z-Image) and [ModelScope](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image). Try our [online demo](https://www.modelscope.cn/aigc/imageGeneration?tab=advanced&versionId=569345&modelType=Checkpoint&sdVersion=Z_IMAGE&modelUrl=modelscope%3A%2F%2FTongyi-MAI%2FZ-Image%3Frevision%3Dmaster)!
*   **[2025-12-08]** 🏆 Z-Image-Turbo ranked 8th overall on the **Artificial Analysis Text-to-Image Leaderboard**, making it the 🥇 <strong style="color: #FFC300;">#1 open-source model</strong>! [Check out the full leaderboard](https://artificialanalysis.ai/image/leaderboard/text-to-image).
*   **[2025-12-01]** 🎉 Our technical report for Z-Image is now available on [arXiv](https://arxiv.org/abs/2511.22699).
*   **[2025-11-26]** 🔥 **Z-Image-Turbo is released!** We have released the model checkpoint on [Hugging Face](https://huggingface.co/Tongyi-MAI/Z-Image-Turbo) and [ModelScope](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo). Try our [online demo](https://huggingface.co/spaces/Tongyi-MAI/Z-Image-Turbo)!

### 📥 Model Zoo

| Model | Pre-Training | SFT | RL | Steps | CFG | Task | Visual Quality | Diversity | Fine-Tunability | Hugging Face | ModelScope |
| :--- | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| **Z-Image-Omni-Base** | ✅ | ❌ | ❌ | 50 | ✅ | Gen. / Editing | Medium | High | Easy | *To be released* | *To be released* |
| **Z-Image** | ✅ | ✅ | ❌ | 50 | ✅ | Gen. | High | Medium | Easy | [![Hugging Face](https://img.shields.io/badge/%F0%9F%A4%97%20Checkpoint%20-Z--Image-yellow)](https://huggingface.co/Tongyi-MAI/Z-Image) <br> [![Hugging Face Space](https://img.shields.io/badge/%F0%9F%A4%97%20Demo-Z--Image-blue)](https://huggingface.co/spaces/Tongyi-MAI/Z-Image) | [![ModelScope Model](https://img.shields.io/badge/🤖%20%20Checkpoint-Z--Image-624aff)](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image) <br> [![ModelScope Space](https://img.shields.io/badge/%F0%9F%A4%96%20Demo-Z--Image-17c7a7)](https://www.modelscope.cn/aigc/imageGeneration?tab=advanced&versionId=569345&modelType=Checkpoint&sdVersion=Z_IMAGE&modelUrl=modelscope%3A%2F%2FTongyi-MAI%2FZ-Image%3Frevision%3Dmaster) |
| **Z-Image-Turbo** | ✅ | ✅ | ✅ | 8 | ❌ | Gen. | Very High | Low | N/A | [![Hugging Face](https://img.shields.io/badge/%F0%9F%A4%97%20Checkpoint%20-Z--Image--Turbo-yellow)](https://huggingface.co/Tongyi-MAI/Z-Image-Turbo) <br> [![Hugging Face Space](https://img.shields.io/badge/%F0%9F%A4%97%20Demo-Z--Image--Turbo-blue)](https://huggingface.co/spaces/Tongyi-MAI/Z-Image-Turbo) | [![ModelScope Model](https://img.shields.io/badge/🤖%20%20Checkpoint-Z--Image--Turbo-624aff)](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo) <br> [![ModelScope Space](https://img.shields.io/badge/%F0%9F%A4%96%20Demo-Z--Image--Turbo-17c7a7)](https://www.modelscope.cn/aigc/imageGeneration?tab=advanced&versionId=469191&modelType=Checkpoint&sdVersion=Z_IMAGE_TURBO&modelUrl=modelscope%3A%2F%2FTongyi-MAI%2FZ-Image-Turbo%3Frevision%3Dmaster) |
| **Z-Image-Edit** | ✅ | ✅ | ❌ | 50 | ✅ | Editing | High | Medium | Easy | *To be released* | *To be released* |
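
As a rough way to read the **Steps** and **CFG** columns together, the sketch below counts forward passes per image. It assumes each CFG step costs one conditional plus one unconditional evaluation, which puts the 50-step models at about 100 NFEs versus 8 for Z-Image-Turbo (about 12.5× fewer under this assumption). The helper `nfe_count` is hypothetical and not part of this repository.

```python
def nfe_count(steps: int, uses_cfg: bool) -> int:
    """Forward passes for one image, assuming CFG adds one unconditional
    pass per step on top of the conditional pass."""
    passes_per_step = 2 if uses_cfg else 1
    return steps * passes_per_step


# (steps, uses_cfg) copied from the Steps and CFG columns of the table above.
MODEL_ZOO = {
    "Z-Image-Omni-Base": (50, True),
    "Z-Image": (50, True),
    "Z-Image-Turbo": (8, False),
    "Z-Image-Edit": (50, True),
}

for name, (steps, uses_cfg) in MODEL_ZOO.items():
    print(f"{name}: {nfe_count(steps, uses_cfg)} NFEs")
```

Batching the conditional and unconditional passes into one call changes wall-clock time, not the number of function evaluations counted here.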

The figure below illustrates at which training stage each model is produced.

![Training Pipeline of Z-Image](assets/training_pipeline.jpg)

### 🖼️ Showcase

📸 **Photorealistic Quality**: **Z-Image-Turbo** delivers strong photorealistic image generation while maintaining excellent aesthetic quality.

![Showcase of Z-Image on Photo-realistic image Generation](assets/showcase_realistic.png)

📖 **Accurate Bilingual Text Rendering**: **Z-Image-Turbo** excels at accurately rendering complex Chinese and English text.

![Showcase of Z-Image on Bilingual Text Rendering](assets/showcase_rendering.png)

💡 **Prompt Enhancing & Reasoning**: The Prompt Enhancer equips the model with reasoning capabilities, letting it go beyond surface-level descriptions and draw on underlying world knowledge.

![Showcase of Prompt Enhancing & Reasoning](assets/reasoning.png)

🧠 **Creative Image Editing**: **Z-Image-Edit** shows a strong understanding of bilingual editing instructions, enabling imaginative and flexible image transformations.

![Showcase of Z-Image-Edit on Image Editing](assets/showcase_editing.png)

### 🏗️ Model Architecture
We adopt a **Scalable Single-Stream DiT** (S3-DiT) architecture. In this setup, text, visual semantic tokens, and image VAE tokens are concatenated at the sequence level to serve as a unified input stream, maximizing parameter efficiency compared to dual-stream approaches.
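
To make the single-stream idea concrete, here is a toy sketch in plain Python; it is not the S3-DiT code in `src/zimage/transformer.py`. It assumes every token group has already been embedded to a shared width `dim`, and the `Token` type, the helper names, and the one-projection-per-stream parameter comparison are illustrative assumptions. The point it shows is that after sequence-level concatenation a single set of weights serves all modalities, whereas a dual-stream block keeps separate weights per stream.

```python
from dataclasses import dataclass


@dataclass
class Token:
    modality: str  # "text", "semantic", or "vae" (illustrative labels)
    vec: list      # embedding of length `dim`, assumed already projected


def single_stream(text, semantic, vae):
    """Concatenate the three token groups along the sequence axis."""
    return text + semantic + vae


def shared_linear(tokens, weight):
    """Apply the SAME dim x dim weight matrix to every token, whatever its modality."""
    return [
        Token(t.modality, [sum(w * x for w, x in zip(row, t.vec)) for row in weight])
        for t in tokens
    ]


def projection_params(dim, num_streams):
    """Parameters of one dim x dim projection, duplicated once per stream."""
    return num_streams * dim * dim


dim = 2
identity = [[1.0, 0.0], [0.0, 1.0]]
seq = single_stream(
    [Token("text", [1.0, 2.0])],
    [Token("semantic", [3.0, 4.0])],
    [Token("vae", [5.0, 6.0]), Token("vae", [7.0, 8.0])],
)
out = shared_linear(seq, identity)
print([t.modality for t in out])  # ['text', 'semantic', 'vae', 'vae']
print(projection_params(dim, num_streams=1), projection_params(dim, num_streams=2))  # 4 8
```

In a real transformer block the concatenated sequence would also pass through joint self-attention, so text and image tokens attend to each other directly within one stream.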

![Architecture of Z-Image and Z-Image-Edit](assets/architecture.webp)

### 📈 Performance

Z-Image-Turbo's performance has been validated on multiple independent benchmarks, where it achieves state-of-the-art results among open-source models.

#### Artificial Analysis Text-to-Image Leaderboard
On the highly competitive [Artificial Analysis Leaderboard](https://artificialanalysis.ai/image/leaderboard/text-to-image), Z-Image-Turbo ranked **8th overall** and secured the top position as the 🥇 <strong style="color: gold;">#1 Open-Source Model</strong>, outperforming all other open-source alternatives.


<p align="center">
  <a href="https://artificialanalysis.ai/image/leaderboard/text-to-image">
    <img src="assets/image_arena_all.jpg" alt="Z-Image Rank on Artificial Analysis Leaderboard"/><br />
    <span style="font-size:1.05em; cursor:pointer; text-decoration:underline;"> Artificial Analysis Leaderboard</span>
  </a>
</p>

<p align="center">
  <a href="https://artificialanalysis.ai/image/leaderboard/text-to-image">
    <img src="assets/image_arena_os.jpg" alt="Z-Image Rank on Artificial Analysis Leaderboard (Open-Source Model Only)"/><br />
    <span style="font-size:1.05em; cursor:pointer; text-decoration:underline;"> Artificial Analysis Leaderboard (Open-Source Model Only)</span>
  </a>
</p>

#### Alibaba AI Arena Text-to-Image Leaderboard
According to the Elo-based Human Preference Evaluation on [*Alibaba AI Arena*](https://aiarena.alibaba-inc.com/corpora/arena/leaderboard?arenaType=T2I), Z-Image-Turbo also achieves state-of-the-art results among open-source models and shows highly competitive performance against leading proprietary models.

<p align="center">
  <a href="https://aiarena.alibaba-inc.com/corpora/arena/leaderboard?arenaType=T2I">
    <img src="assets/leaderboard.png" alt="Z-Image Elo Rating on AI Arena"/><br />
    <span style="font-size:1.05em; cursor:pointer; text-decoration:underline;"> Alibaba AI Arena Text-to-Image Leaderboard</span>
  </a>
</p>


### 🚀 Quick Start
#### (1) PyTorch Native Inference
Create a virtual environment of your choice, then install the dependencies:
```bash
pip install -e .
```
Then run the following command to generate an image:
```bash
python inference.py
```

#### (2) Diffusers Inference
Install the latest version of diffusers from source with the following command:
<details>
  <summary>Click here for details on why you need to install diffusers from source</summary>

  We have submitted two pull requests ([#12703](https://github.com/huggingface/diffusers/pull/12703) and [#12715](https://github.com/huggingface/diffusers/pull/12715)) to the 🤗 diffusers repository to add support for Z-Image, and both have been merged.
  Installing diffusers from source ensures you have the latest features and Z-Image support.

</details>

```bash
pip install git+https://github.com/huggingface/diffusers
```

<details>
<summary><b>Z-Image-Turbo</b> - Click to expand</summary>

Then, try the following code to generate an image:
```python
import torch
from diffusers import ZImagePipeline

# 1. Load the pipeline
# Use bfloat16 for optimal performance on supported GPUs
pipe = ZImagePipeline.from_pretrained(
    "Tongyi-MAI/Z-Image-Turbo",
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=False,
)
pipe.to("cuda")

# [Optional] Attention Backend
# Diffusers uses SDPA by default. Switch to Flash Attention for better efficiency if supported:
# pipe.transformer.set_attention_backend("flash")    # Enable Flash-Attention-2
# pipe.transformer.set_attention_backend("_flash_3") # Enable Flash-Attention-3

# [Optional] Model Compilation
# Compiling the DiT model accelerates inference, but the first run will take longer to compile.
# pipe.transformer.compile()

# [Optional] CPU Offloading
# Enable CPU offloading for memory-constrained devices.
# pipe.enable_model_cpu_offload()

# 西安大雁塔 refers to the Giant Wild Goose Pagoda in Xi'an.
prompt = "Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights."

# 2. Generate Image
image = pipe(
    prompt=prompt,
    height=1024,
    width=1024,
    num_inference_steps=9,  # This actually results in 8 DiT forwards
    guidance_scale=0.0,     # Guidance should be 0 for the Turbo models
    generator=torch.Generator("cuda").manual_seed(42),
).images[0]

image.save("example.png")
```

</details>

<details>
<summary><b>Z-Image</b> - Click to expand</summary>

Recommended Parameters:
- **Resolution:** 512×512 to 2048×2048 (total pixel area, any aspect ratio)
- **Guidance scale:** 3.0 – 5.0
- **Inference steps:** 28 – 50
- **Negative prompts:** Strongly recommended for better control
- **CFG normalization:** `False` for general stylization, `True` for realism
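
The recommended ranges can be checked before calling the pipeline. The helper below is hypothetical, not part of diffusers or this repository; the ranges are transcribed from the list above, and snapping sizes to a multiple of 16 in `size_for_aspect` is an assumption rather than a documented requirement.

```python
import math

# Ranges transcribed from the "Recommended Parameters" list above.
MIN_AREA, MAX_AREA = 512 * 512, 2048 * 2048
GUIDANCE_RANGE = (3.0, 5.0)
STEPS_RANGE = (28, 50)


def check_z_image_params(height, width, guidance_scale, num_inference_steps):
    """Return a list of warnings for values outside the recommended ranges."""
    warnings = []
    area = height * width
    if not MIN_AREA <= area <= MAX_AREA:
        warnings.append(f"pixel area {area} outside [{MIN_AREA}, {MAX_AREA}]")
    lo, hi = GUIDANCE_RANGE
    if not lo <= guidance_scale <= hi:
        warnings.append(f"guidance_scale {guidance_scale} outside [{lo}, {hi}]")
    lo, hi = STEPS_RANGE
    if not lo <= num_inference_steps <= hi:
        warnings.append(f"num_inference_steps {num_inference_steps} outside [{lo}, {hi}]")
    return warnings


def size_for_aspect(area, aspect_ratio, multiple=16):
    """Return (height, width) close to `area` pixels with width / height ~= aspect_ratio.

    Snapping to `multiple` is an assumption, not a documented Z-Image requirement.
    """
    height = math.sqrt(area / aspect_ratio)
    width = height * aspect_ratio

    def snap(value):
        return max(multiple, round(value / multiple) * multiple)

    return snap(height), snap(width)


print(check_z_image_params(1280, 720, 4, 50))  # []
print(size_for_aspect(1024 * 1024, 16 / 9))    # (768, 1360)
```

For the 1280×720 example below, the check returns no warnings, since 921,600 pixels falls between 512² and 2048².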

Then, try the following code to generate an image:
```python
import torch
from diffusers import ZImagePipeline

# Load the pipeline
pipe = ZImagePipeline.from_pretrained(
    "Tongyi-MAI/Z-Image",
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=False,
)
pipe.to("cuda")

# Generate image
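# English gloss of the prompt: Two young Asian women stand close together against a plain gray
# textured wall, possibly indoors on a carpeted floor. The woman on the left has long curly hair, a navy
# sweater with cream ruffle trim on the left sleeve over a white stand-collar shirt, white trousers,
# small gold stud earrings, and arms crossed behind her back. The woman on the right has straight
# shoulder-length hair, a cream sweatshirt printed with "Tun the tables" on the chest and "New ideas"
# below it, white trousers, small silver hoop earrings, and arms crossed over her chest. Both smile at
# the camera. Photo, natural light, soft shadows, neutral navy and cream palette, casual fashion
# photography, medium depth of field, faces and upper bodies in sharp focus, relaxed poses, friendly
# expressions, indoor setting, carpeted floor, plain background.
prompt = "两名年轻亚裔女性紧密站在一起,背景为朴素的灰色纹理墙面,可能是室内地毯地面。左侧女性留着长卷发,身穿藏青色毛衣,左袖有奶油色褶皱装饰,内搭白色立领衬衫,下身白色裤子;佩戴小巧金色耳钉,双臂交叉于背后。右侧女性留直肩长发,身穿奶油色卫衣,胸前印有“Tun the tables”字样,下方为“New ideas”,搭配白色裤子;佩戴银色小环耳环,双臂交叉于胸前。两人均面带微笑直视镜头。照片,自然光照明,柔和阴影,以藏青、奶油白为主的中性色调,休闲时尚摄影,中等景深,面部和上半身对焦清晰,姿态放松,表情友好,室内环境,地毯地面,纯色背景。"
negative_prompt = ""  # Optional; useful for suppressing unwanted content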

image = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=1280,
    width=720,
    cfg_normalization=False,
    num_inference_steps=50,
    guidance_scale=4,
    generator=torch.Generator("cuda").manual_seed(42),
).images[0]

image.save("example.png")
```

</details>

## 🔬 Decoupled-DMD: The Acceleration Magic Behind Z-Image

[![arXiv](https://img.shields.io/badge/arXiv-2511.22677-b31b1b.svg)](https://arxiv.org/abs/2511.22677)

Decoupled-DMD is the core few-step distillation algorithm that powers the 8-step Z-Image model.

Our core insight in Decoupled-DMD is that the success of existing DMD (Distribution Matching Distillation) methods stems from two independent but collaborating mechanisms:

-   **CFG Augmentation (CA)**: The primary **engine** 🚀 driving the distillation process, a factor largely overlooked in previous work.
-   **Distribution Matching (DM)**: Acts more as a **regularizer** ⚖️, ensuring the stability and quality of the generated output.

By recognizing and decoupling these two mechanisms, we were able to study and optimize them in isolation. This ultimately motivated us to develop an improved distillation process that significantly enhances the performance of few-step generation.

![Diagram of Decoupled-DMD](assets/decoupled-dmd.webp)
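
The two mechanisms can be separated with a short algebraic check. This is a minimal sketch, assuming the teacher's guided prediction uses the standard CFG formula `uncond + w * (cond - uncond)` and treating model outputs as plain float lists; the function names are hypothetical and this is not the authors' training code. It shows that the difference between the guided teacher prediction and the fake-model prediction splits into a DM-style term (`cond - fake`) plus a CA-style term (`(w - 1) * (cond - uncond)`).

```python
import random
from typing import List, Tuple


def cfg(cond: List[float], uncond: List[float], w: float) -> List[float]:
    """Standard classifier-free guidance: uncond + w * (cond - uncond)."""
    return [u + w * (c - u) for c, u in zip(cond, uncond)]


def decompose(
    real_cond: List[float], real_uncond: List[float], fake: List[float], w: float
) -> Tuple[List[float], List[float]]:
    """Split (cfg(real) - fake) into a distribution-matching term and a CFG-augmentation term."""
    dm = [c - f for c, f in zip(real_cond, fake)]  # DM: conditional teacher vs. fake model
    ca = [(w - 1.0) * (c - u) for c, u in zip(real_cond, real_uncond)]  # CA: scaled guidance direction
    return dm, ca


# Random vectors stand in for teacher ("real") and fake-model predictions.
random.seed(0)
n, w = 8, 4.0
real_cond = [random.gauss(0.0, 1.0) for _ in range(n)]
real_uncond = [random.gauss(0.0, 1.0) for _ in range(n)]
fake = [random.gauss(0.0, 1.0) for _ in range(n)]

full = [g - f for g, f in zip(cfg(real_cond, real_uncond, w), fake)]
dm, ca = decompose(real_cond, real_uncond, fake, w)

# cfg(real) - fake == (real_cond - fake) + (w - 1) * (real_cond - real_uncond)
assert all(abs(x - (d + a)) < 1e-9 for x, d, a in zip(full, dm, ca))
```

With `w = 1` the CA-style term vanishes and only the distribution-matching difference remains, so in this toy view the guidance scale alone controls how much the CA term contributes.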

## 🤖 DMDR: Fusing DMD with Reinforcement Learning

[![arXiv](https://img.shields.io/badge/arXiv-2511.13649-b31b1b.svg)](https://arxiv.org/abs/2511.13649)

Building upon the strong foundation of Decoupled-DMD, our 8-step Z-Image model has already demonstrated exceptional capabilities. To achieve further improvements in terms of semantic alignment, aesthetic quality, and structural coherence—while producing images with richer high-frequency details—we present **DMDR**.

Our core insight behind DMDR is that Reinforcement Learning (RL) and Distribution Matching Distillation (DMD) can be synergistically integrated during the post-training of few-step models. We demonstrate that:

-   **RL Unlocks the Performance of DMD** 🚀
-   **DMD Effectively Regularizes RL** ⚖️

![Diagram of DMDR](assets/DMDR.webp)

## 🎉 Community Works

- [Cache-DiT](https://github.com/vipshop/cache-dit) provides inference acceleration for **Z-Image** and **Z-Image-ControlNet** via DBCache, Context Parallelism and Tensor Parallelism. It achieves nearly **4x** speedup on 4 GPUs with negligible precision loss. Please visit their [example](https://github.com/vipshop/cache-dit/blob/main/examples) for more details.
- [stable-diffusion.cpp](https://github.com/leejet/stable-diffusion.cpp) is a pure C++ diffusion model inference engine that supports fast and memory-efficient Z-Image inference across multiple platforms (CUDA, Vulkan, etc.). You can use stable-diffusion.cpp to generate images with Z-Image on machines with as little as **4GB** of VRAM. For more information, please refer to [How to Use Z‐Image on a GPU with Only 4GB VRAM](https://github.com/leejet/stable-diffusion.cpp/wiki/How-to-Use-Z%E2%80%90Image-on-a-GPU-with-Only-4GB-VRAM).
- [LeMiCa](https://github.com/UnicomAI/LeMiCa) provides a training-free, timestep-level acceleration method that conveniently speeds up Z-Image inference. For more details, see [LeMiCa4Z-Image](https://github.com/UnicomAI/LeMiCa/tree/main/LeMiCa4Z-Image).
- [ComfyUI ZImageLatent](https://github.com/HellerCommaA/ComfyUI-ZImageLatent) provides easy-to-use latents at the official Z-Image resolutions.
- [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio) provides extended support for Z-Image, including LoRA training, full training, distillation training, and low-VRAM inference. Please refer to DiffSynth-Studio's [documentation](https://github.com/modelscope/DiffSynth-Studio/blob/main/docs/en/Model_Details/Z-Image.md).
- [vllm-omni](https://github.com/vllm-project/vllm-omni), a framework for fast inference and serving of omni-modality models, now [supports](https://github.com/vllm-project/vllm-omni/blob/main/docs/models/supported_models.md) Z-Image.
- [SGLang-Diffusion](https://lmsys.org/blog/2025-11-07-sglang-diffusion/) brings SGLang's state-of-the-art performance to accelerate image and video generation for diffusion models, now [supporting](https://github.com/sgl-project/sglang/blob/main/python/sglang/multimodal_gen/runtime/pipelines/zimage_pipeline.py) Z-Image.
- [Candle](https://github.com/huggingface/candle), a minimalist machine learning (ML) framework for Rust from Hugging Face, now [supports](https://github.com/huggingface/candle/pull/3261) Z-Image.
- [MeanCache](https://github.com/UnicomAI/MeanCache) is a training-free inference acceleration method for Flow Matching models from the China Unicom Data Science and Artificial Intelligence Research Institute. It delivers up to **3.7x** speedup for **Z-Image** generation with plug-and-play integration while preserving output quality.

## 🚀 Star History

[![Star History Chart](https://api.star-history.com/svg?repos=Tongyi-MAI/Z-Image&type=date&legend=top-left)](https://www.star-history.com/#Tongyi-MAI/Z-Image&type=date&legend=top-left)


## 📜 Citation

If you find our work useful in your research, please consider citing:

```bibtex
@article{team2025zimage,
  title={Z-Image: An Efficient Image Generation Foundation Model with Single-Stream Diffusion Transformer},
  author={Z-Image Team},
  journal={arXiv preprint arXiv:2511.22699},
  year={2025}
}

@article{liu2025decoupled,
  title={Decoupled DMD: CFG Augmentation as the Spear, Distribution Matching as the Shield},
  author={Dongyang Liu and Peng Gao and David Liu and Ruoyi Du and Zhen Li and Qilong Wu and Xin Jin and Sihan Cao and Shifeng Zhang and Hongsheng Li and Steven Hoi},
  journal={arXiv preprint arXiv:2511.22677},
  year={2025}
}

@article{jiang2025distribution,
  title={Distribution Matching Distillation Meets Reinforcement Learning},
  author={Jiang, Dengyang and Liu, Dongyang and Wang, Zanyi and Wu, Qilong and Jin, Xin and Liu, David and Li, Zhen and Wang, Mengmeng and Gao, Peng and Yang, Harry},
  journal={arXiv preprint arXiv:2511.13649},
  year={2025}
}

```

## 🤝 We're Hiring!

We're actively looking for **Research Scientists**, **Engineers**, and **Interns** to work on foundational generative models and their applications. Interested candidates, please send your resume to: **jingpeng.gp@alibaba-inc.com**


================================================
FILE: batch_inference.py
================================================
"""Batch prompt inference for Z-Image."""

import os
from pathlib import Path
import time

import torch

from utils import AttentionBackend, load_from_local_dir, set_attention_backend
from utils import ensure_model_weights as ensure_weights
from zimage import generate


def read_prompts(path: str) -> list[str]:
    """Read prompts from a text file (one per line, empty lines skipped)."""

    prompt_path = Path(path)
    if not prompt_path.exists():
        raise FileNotFoundError(f"Prompt file not found: {prompt_path}")
    with prompt_path.open("r", encoding="utf-8") as f:
        prompts = [line.strip() for line in f if line.strip()]
    if not prompts:
        raise ValueError(f"No prompts found in {prompt_path}")
    return prompts


PROMPTS = read_prompts(os.environ.get("PROMPTS_FILE", "prompts/prompt1.txt"))


def slugify(text: str, max_len: int = 60) -> str:
    """Create a filesystem-safe slug from the prompt."""

    slug = "".join(ch.lower() if ch.isalnum() else "-" for ch in text)
    slug = "-".join(part for part in slug.split("-") if part)
    return slug[:max_len].rstrip("-") or "prompt"


def select_device() -> str:
    """Pick the best available device, in priority order: CUDA, TPU (torch_xla), MPS, then CPU."""

    if torch.cuda.is_available():
        print("Chosen device: cuda")
        return "cuda"
    try:
        import torch_xla.core.xla_model as xm

        device = xm.xla_device()
        print("Chosen device: tpu")
        return device
    except (ImportError, RuntimeError):
        if torch.backends.mps.is_available():
            print("Chosen device: mps")
            return "mps"
        print("Chosen device: cpu")
        return "cpu"


def main():
    model_path = ensure_weights("ckpts/Z-Image-Turbo")
    dtype = torch.bfloat16
    compile = False
    height = 1024
    width = 1024
    num_inference_steps = 8
    guidance_scale = 0.0
    attn_backend = os.environ.get("ZIMAGE_ATTENTION", "_native_flash")
    output_dir = Path("outputs")
    output_dir.mkdir(exist_ok=True)

    device = select_device()

    components = load_from_local_dir(model_path, device=device, dtype=dtype, compile=compile)
    AttentionBackend.print_available_backends()
    set_attention_backend(attn_backend)
    print(f"Chosen attention backend: {attn_backend}")

    for idx, prompt in enumerate(PROMPTS, start=1):
        output_path = output_dir / f"prompt-{idx:02d}-{slugify(prompt)}.png"
        seed = 42 + idx - 1
        generator = torch.Generator(device).manual_seed(seed)

        start_time = time.time()
        images = generate(
            prompt=prompt,
            **components,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            generator=generator,
        )
        elapsed = time.time() - start_time
        images[0].save(output_path)
        print(f"[{idx}/{len(PROMPTS)}] Saved {output_path} in {elapsed:.2f} seconds")

    print("Done.")


if __name__ == "__main__":
    main()


================================================
FILE: inference.py
================================================
"""Z-Image PyTorch Native Inference."""

import os
import time
import warnings

import torch

warnings.filterwarnings("ignore")
from utils import AttentionBackend, ensure_model_weights, load_from_local_dir, set_attention_backend
from zimage import generate


def main():
    model_path = ensure_model_weights("ckpts/Z-Image-Turbo", verify=False)  # True to verify with md5
    dtype = torch.bfloat16
    compile = False  # default False for compatibility
    output_path = "example.png"
    height = 1024
    width = 1024
    num_inference_steps = 8
    guidance_scale = 0.0
    seed = 42
    attn_backend = os.environ.get("ZIMAGE_ATTENTION", "_native_flash")
    prompt = (
        "Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. "
        "Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. "
        "Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. Soft-lit outdoor night background, "
        "silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights."
    )

    # Device selection priority: cuda -> tpu -> mps -> cpu
    if torch.cuda.is_available():
        device = "cuda"
        print("Chosen device: cuda")
    else:
        try:
            import torch_xla
            import torch_xla.core.xla_model as xm

            device = xm.xla_device()
            print("Chosen device: tpu")
        except (ImportError, RuntimeError):
            if torch.backends.mps.is_available():
                device = "mps"
                print("Chosen device: mps")
            else:
                device = "cpu"
                print("Chosen device: cpu")
    # Load models
    components = load_from_local_dir(model_path, device=device, dtype=dtype, compile=compile)
    AttentionBackend.print_available_backends()
    set_attention_backend(attn_backend)
    print(f"Chosen attention backend: {attn_backend}")

    # Generate an image
    start_time = time.time()
    images = generate(
        prompt=prompt,
        **components,
        height=height,
        width=width,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=torch.Generator(device).manual_seed(seed),
    )
    end_time = time.time()
    print(f"Time taken: {end_time - start_time:.2f} seconds")
    images[0].save(output_path)

    ### For the best speed, use the `_flash_3` attention backend and set `compile=True`.
    ### After warm-up, this can give sub-second generation on Hopper GPUs (H100/H200/H800).


if __name__ == "__main__":
    main()


================================================
FILE: pyproject.toml
================================================
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"

[project]
name = "zimage-native"
version = "0.1.0"
description = "Z-Image PyTorch Native Implementation"
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
    "torch>=2.5.0",
    "transformers>=4.51.0",
    "safetensors",
    "loguru",
    "pillow",
    "accelerate",
    "huggingface_hub>=0.25.0"
]

[project.optional-dependencies]
dev = [
    "black",
    "isort",
    "ruff"
]

[tool.setuptools.packages.find]
where = ["src"]


================================================
FILE: src/__init__.py
================================================
"""Z-Image Native Implementation."""

from .utils import load_from_local_dir
from .zimage import ZImageTransformer2DModel, generate

__version__ = "0.1.0"

__all__ = [
    "ZImageTransformer2DModel",
    "generate",
    "load_from_local_dir",
]


================================================
FILE: src/config/__init__.py
================================================
"""Z-Image Configuration."""

from .inference import (
    DEFAULT_CFG_TRUNCATION,
    DEFAULT_GUIDANCE_SCALE,
    DEFAULT_HEIGHT,
    DEFAULT_INFERENCE_STEPS,
    DEFAULT_MAX_SEQUENCE_LENGTH,
    DEFAULT_WIDTH,
)
from .model import (
    ADALN_EMBED_DIM,
    BASE_IMAGE_SEQ_LEN,
    BASE_SHIFT,
    BYTES_PER_GB,
    DEFAULT_LOAD_DEVICE,
    DEFAULT_LOAD_DTYPE_STR,
    DEFAULT_SCHEDULER_NUM_TRAIN_TIMESTEPS,
    DEFAULT_SCHEDULER_SHIFT,
    DEFAULT_SCHEDULER_USE_DYNAMIC_SHIFTING,
    DEFAULT_TRANSFORMER_CAP_FEAT_DIM,
    DEFAULT_TRANSFORMER_DIM,
    DEFAULT_TRANSFORMER_F_PATCH_SIZE,
    DEFAULT_TRANSFORMER_IN_CHANNELS,
    DEFAULT_TRANSFORMER_N_HEADS,
    DEFAULT_TRANSFORMER_N_KV_HEADS,
    DEFAULT_TRANSFORMER_N_LAYERS,
    DEFAULT_TRANSFORMER_N_REFINER_LAYERS,
    DEFAULT_TRANSFORMER_NORM_EPS,
    DEFAULT_TRANSFORMER_PATCH_SIZE,
    DEFAULT_TRANSFORMER_QK_NORM,
    DEFAULT_TRANSFORMER_T_SCALE,
    DEFAULT_VAE_IN_CHANNELS,
    DEFAULT_VAE_LATENT_CHANNELS,
    DEFAULT_VAE_NORM_NUM_GROUPS,
    DEFAULT_VAE_OUT_CHANNELS,
    DEFAULT_VAE_SCALE_FACTOR,
    DEFAULT_VAE_SCALING_FACTOR,
    FREQUENCY_EMBEDDING_SIZE,
    MAX_IMAGE_SEQ_LEN,
    MAX_PERIOD,
    MAX_SHIFT,
    ROPE_AXES_DIMS,
    ROPE_AXES_LENS,
    ROPE_THETA,
    SEQ_MULTI_OF,
)

__all__ = [
    "ADALN_EMBED_DIM",
    "SEQ_MULTI_OF",
    "ROPE_THETA",
    "ROPE_AXES_DIMS",
    "ROPE_AXES_LENS",
    "FREQUENCY_EMBEDDING_SIZE",
    "MAX_PERIOD",
    "BASE_IMAGE_SEQ_LEN",
    "MAX_IMAGE_SEQ_LEN",
    "BASE_SHIFT",
    "MAX_SHIFT",
    "DEFAULT_VAE_SCALE_FACTOR",
    "DEFAULT_VAE_IN_CHANNELS",
    "DEFAULT_VAE_OUT_CHANNELS",
    "DEFAULT_VAE_LATENT_CHANNELS",
    "DEFAULT_VAE_NORM_NUM_GROUPS",
    "DEFAULT_VAE_SCALING_FACTOR",
    "DEFAULT_TRANSFORMER_PATCH_SIZE",
    "DEFAULT_TRANSFORMER_F_PATCH_SIZE",
    "DEFAULT_TRANSFORMER_IN_CHANNELS",
    "DEFAULT_TRANSFORMER_DIM",
    "DEFAULT_TRANSFORMER_N_LAYERS",
    "DEFAULT_TRANSFORMER_N_REFINER_LAYERS",
    "DEFAULT_TRANSFORMER_N_HEADS",
    "DEFAULT_TRANSFORMER_N_KV_HEADS",
    "DEFAULT_TRANSFORMER_NORM_EPS",
    "DEFAULT_TRANSFORMER_QK_NORM",
    "DEFAULT_TRANSFORMER_CAP_FEAT_DIM",
    "DEFAULT_TRANSFORMER_T_SCALE",
    "DEFAULT_SCHEDULER_NUM_TRAIN_TIMESTEPS",
    "DEFAULT_SCHEDULER_SHIFT",
    "DEFAULT_SCHEDULER_USE_DYNAMIC_SHIFTING",
    "DEFAULT_LOAD_DEVICE",
    "DEFAULT_LOAD_DTYPE_STR",
    "BYTES_PER_GB",
    "DEFAULT_HEIGHT",
    "DEFAULT_WIDTH",
    "DEFAULT_INFERENCE_STEPS",
    "DEFAULT_GUIDANCE_SCALE",
    "DEFAULT_CFG_TRUNCATION",
    "DEFAULT_MAX_SEQUENCE_LENGTH",
]


================================================
FILE: src/config/inference.py
================================================
"""Inference-specific configuration for Z-Image."""

DEFAULT_HEIGHT = 1024
DEFAULT_WIDTH = 1024
DEFAULT_INFERENCE_STEPS = 8
DEFAULT_GUIDANCE_SCALE = 0.0
DEFAULT_CFG_TRUNCATION = 1.0
DEFAULT_MAX_SEQUENCE_LENGTH = 512


================================================
FILE: src/config/manifests/README.md
================================================
# Model Manifests

This directory contains manifest files for different Z-Image model variants.

## Purpose

Manifest files list all required files for each model, optionally with MD5 checksums for integrity verification.

## File Naming Convention

- `z-image-turbo.txt` - Z-Image Turbo model
- Custom models: `{model-name}.txt`

## Format

### Standard Format (with MD5 - Recommended)

```txt
# Z-Image Model Manifest
# Format: <md5hash>  <filepath>
# Generated automatically - DO NOT edit manually

5e3226ed72a9a4a080f2a4ca78b98ddc  model_index.json
ca682fcc6c5a94cf726b7187e64b9411  scheduler/scheduler_config.json
1e97eb35d9d0b6aa60c58a8df8d7d99a  text_encoder/config.json
30b85686b9a9b002e012494fadc027cb  text_encoder/model-00001-of-00003.safetensors
...
```

**Verification Behavior:**
- `verify=False`: Default, only checks file existence, ignores MD5 (fast)
- `verify=True`: Checks existence AND verifies MD5 checksums (thorough)

## Usage

The manifest file is automatically selected based on the model directory name:

```python
# Auto-detects manifest from "Z-Image-Turbo" -> uses z-image-turbo.txt
model_path = ensure_model_weights("ckpts/Z-Image-Turbo")

# Explicit manifest
model_path = ensure_model_weights("ckpts/Z-Image-Turbo", manifest_name="z-image-turbo.txt")
```

## Generating Manifests

Use the provided tool to generate manifests:

```bash
# Generate with MD5 checksums (auto-saves to this directory)
python -m src.tools.generate_manifest ckpts/Z-Image-Turbo

# Generate without checksums (faster, not recommended)
python -m src.tools.generate_manifest ckpts/Z-Image-Turbo --no-checksums

# With verbose output
python -m src.tools.generate_manifest ckpts/Z-Image-Turbo --verbose

# Custom output path
python -m src.tools.generate_manifest ckpts/Z-Image-Turbo --output custom.txt
```

## Available Manifests

- **z-image-turbo.txt** - Z-Image Turbo model


================================================
FILE: src/config/manifests/z-image-turbo.txt
================================================
# Z-Image Model Manifest
# Format: <md5hash>  <filepath>
# Generated automatically - DO NOT edit manually

5e3226ed72a9a4a080f2a4ca78b98ddc  model_index.json
ca682fcc6c5a94cf726b7187e64b9411  scheduler/scheduler_config.json
1e97eb35d9d0b6aa60c58a8df8d7d99a  text_encoder/config.json
30b85686b9a9b002e012494fadc027cb  text_encoder/model-00001-of-00003.safetensors
e6a24ea164404a01ad2800dbae4e1a13  text_encoder/model-00002-of-00003.safetensors
09e190ed15ff14795b6277e023cfcb2d  text_encoder/model-00003-of-00003.safetensors
589f5395156900f49d617aee8a8d8708  text_encoder/model.safetensors.index.json
6423133b9cc1a2077b57822c30c211aa  tokenizer/tokenizer.json
b06e103ac555ec4b51266078b518c0f0  tokenizer/tokenizer_config.json
baed87136fe5f848e24b072f99856cc3  transformer/config.json
54889d0dd179b4fa2fd7bd0e487d856e  transformer/diffusion_pytorch_model-00001-of-00003.safetensors
fe81e804658d345323512c63224b0604  transformer/diffusion_pytorch_model-00002-of-00003.safetensors
4e074e09129f98ad840414951f122feb  transformer/diffusion_pytorch_model-00003-of-00003.safetensors
76d788eb0d42c59cc8f8ec007db639aa  transformer/diffusion_pytorch_model.safetensors.index.json
ba9e2980c8630b4abccc643bc9f4a542  vae/config.json
6f83de55cb720c7fae051b14528577bf  vae/diffusion_pytorch_model.safetensors


================================================
FILE: src/config/model.py
================================================
"""Model configuration constants for Z-Image."""

ADALN_EMBED_DIM = 256
SEQ_MULTI_OF = 32

ROPE_THETA = 256.0
ROPE_AXES_DIMS = [32, 48, 48]
ROPE_AXES_LENS = [1536, 512, 512]

FREQUENCY_EMBEDDING_SIZE = 256
MAX_PERIOD = 10000

BASE_IMAGE_SEQ_LEN = 256
MAX_IMAGE_SEQ_LEN = 4096
BASE_SHIFT = 0.5
MAX_SHIFT = 1.15

DEFAULT_VAE_SCALE_FACTOR = 8
DEFAULT_VAE_IN_CHANNELS = 3
DEFAULT_VAE_OUT_CHANNELS = 3
DEFAULT_VAE_LATENT_CHANNELS = 4
DEFAULT_VAE_NORM_NUM_GROUPS = 32
DEFAULT_VAE_SCALING_FACTOR = 0.18215

DEFAULT_TRANSFORMER_PATCH_SIZE = (2,)
DEFAULT_TRANSFORMER_F_PATCH_SIZE = (1,)
DEFAULT_TRANSFORMER_IN_CHANNELS = 16
DEFAULT_TRANSFORMER_DIM = 3840
DEFAULT_TRANSFORMER_N_LAYERS = 30
DEFAULT_TRANSFORMER_N_REFINER_LAYERS = 2
DEFAULT_TRANSFORMER_N_HEADS = 30
DEFAULT_TRANSFORMER_N_KV_HEADS = 30
DEFAULT_TRANSFORMER_NORM_EPS = 1e-5
DEFAULT_TRANSFORMER_QK_NORM = True
DEFAULT_TRANSFORMER_CAP_FEAT_DIM = 2560
DEFAULT_TRANSFORMER_T_SCALE = 1000.0

DEFAULT_SCHEDULER_NUM_TRAIN_TIMESTEPS = 1000
DEFAULT_SCHEDULER_SHIFT = 3.0
DEFAULT_SCHEDULER_USE_DYNAMIC_SHIFTING = False

DEFAULT_LOAD_DEVICE = "cuda"
DEFAULT_LOAD_DTYPE_STR = "bfloat16"

BYTES_PER_GB = 2**30


================================================
FILE: src/tools/__init__.py
================================================
"""Tools for Z-Image model management."""

from .generate_manifest import compute_md5, get_essential_files

__all__ = [
    "compute_md5",
    "get_essential_files",
]



================================================
FILE: src/tools/generate_manifest.py
================================================
#!/usr/bin/env python3
"""Generate manifest file with MD5 checksums for model weights.

Usage:
    python -m tools.generate_manifest ckpts/Z-Image-Turbo
    python -m tools.generate_manifest ckpts/Z-Image-Turbo --no-checksums  # Only list files
"""

import argparse
import hashlib
from pathlib import Path
from typing import List


def compute_md5(file_path: Path, chunk_size: int = 8192) -> str:
    """Compute MD5 hash of a file."""
    md5_hash = hashlib.md5()
    with open(file_path, "rb") as f:
        while chunk := f.read(chunk_size):
            md5_hash.update(chunk)
    return md5_hash.hexdigest()


def get_essential_files(model_dir: Path) -> List[Path]:
    """Get list of essential model files."""
    essential_patterns = [
        "model_index.json",
        "transformer/config.json",
        "transformer/*.safetensors*",
        "vae/config.json",
        "vae/*.safetensors",
        "text_encoder/config.json",
        "text_encoder/*.safetensors*",
        "tokenizer/tokenizer.json",
        "tokenizer/tokenizer_config.json",
        "scheduler/scheduler_config.json",
    ]
    
    files = []
    for pattern in essential_patterns:
        if "*" in pattern:
            files.extend(model_dir.glob(pattern))
        else:
            file_path = model_dir / pattern
            if file_path.exists():
                files.append(file_path)
    
    return sorted(files)


def main():
    parser = argparse.ArgumentParser(description="Generate manifest file for model weights")
    parser.add_argument("model_dir", type=str, help="Path to model directory")
    parser.add_argument("--output", "-o", type=str, default=None,
                        help="Output manifest file path (default: auto-detect to config/manifests/)")
    parser.add_argument("--no-checksums", action="store_true",
                        help="Only list files without computing checksums")
    parser.add_argument("--verbose", "-v", action="store_true",
                        help="Print progress")
    
    args = parser.parse_args()
    
    model_dir = Path(args.model_dir)
    if not model_dir.exists():
        print(f"Error: Model directory not found: {model_dir}")
        return 1
    
    # Determine output path
    if args.output:
        output_file = Path(args.output)
    else:
        # Auto-detect: save to config/manifests/{model-name}.txt
        model_name = model_dir.name.lower()  # e.g., "Z-Image-Turbo" -> "z-image-turbo"
        script_dir = Path(__file__).parent
        config_dir = script_dir.parent / "config" / "manifests"
        config_dir.mkdir(parents=True, exist_ok=True)
        output_file = config_dir / f"{model_name}.txt"
    
    # Get essential files
    files = get_essential_files(model_dir)
    
    if not files:
        print(f"Warning: No essential files found in {model_dir}")
        return 1
    
    print(f"Found {len(files)} essential files")
    
    # Generate manifest
    with open(output_file, "w", encoding="utf-8") as f:
        f.write("# Z-Image Model Manifest\n")
        if args.no_checksums:
            f.write("# Format: <filepath>\n")
        else:
            f.write("# Format: <md5hash>  <filepath>\n")
        f.write("# Generated automatically - DO NOT edit manually\n\n")
        
        for file_path in files:
            rel_path = file_path.relative_to(model_dir)
            
            if args.no_checksums:
                f.write(f"{rel_path}\n")
                if args.verbose:
                    print(f"  {rel_path}")
            else:
                if args.verbose:
                    print(f"Computing MD5 for {rel_path}...", end=" ", flush=True)
                
                try:
                    md5_hash = compute_md5(file_path)
                    f.write(f"{md5_hash}  {rel_path}\n")
                    if args.verbose:
                        print(f"✓ {md5_hash}")
                except Exception as e:
                    print(f"✗ Error computing MD5 for {rel_path}: {e}")
                    continue
    
    print(f"\n✓ Manifest saved to: {output_file}")
    print(f"  Total files: {len(files)}")
    if not args.no_checksums:
        print("  With MD5 checksums for integrity verification")
    
    return 0


if __name__ == "__main__":
    raise SystemExit(main())



================================================
FILE: src/utils/__init__.py
================================================
"""Utilities for Z-Image."""

from .attention import AttentionBackend, dispatch_attention, set_attention_backend
from .helpers import format_bytes, print_memory_stats, ensure_model_weights
from .loader import load_from_local_dir

__all__ = [
    "load_from_local_dir",
    "format_bytes",
    "print_memory_stats",
    "ensure_model_weights",
    "AttentionBackend",
    "set_attention_backend",
    "dispatch_attention",
]


================================================
FILE: src/utils/attention.py
================================================
"""Attention backend utilities for Z-Image."""

# Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_dispatch.py
from enum import Enum
import functools
import inspect
from typing import Callable, Dict, List, Optional, Union

import torch
import torch.nn.functional as F

from .import_utils import is_flash_attn_3_available, is_flash_attn_available, is_torch_version

_CAN_USE_FLASH_ATTN_2 = is_flash_attn_available()
_CAN_USE_FLASH_ATTN_3 = is_flash_attn_3_available()

# MPS Flash Attention (Apple Silicon)
try:
    import mps_flash_attn
    _CAN_USE_MPS_FLASH = mps_flash_attn.is_available()
except ImportError:
    _CAN_USE_MPS_FLASH = False
    mps_flash_attn = None
_TORCH_VERSION_CHECK = is_torch_version(">=", "2.5.0")  # SDPA gained the `enable_gqa` argument in PyTorch 2.5

if not _TORCH_VERSION_CHECK:
    raise RuntimeError("PyTorch version must be >= 2.5.0 to use this backend.")
else:
    print("PyTorch version is >= 2.5.0, check pass.")

if _CAN_USE_FLASH_ATTN_2:
    from flash_attn import flash_attn_func, flash_attn_varlen_func
else:
    flash_attn_func = None
    flash_attn_varlen_func = None

if _CAN_USE_FLASH_ATTN_3:
    from flash_attn_interface import (
        flash_attn_func as flash_attn_3_func,
        flash_attn_varlen_func as flash_attn_3_varlen_func,
    )

    _flash_attn_3_sig = inspect.signature(flash_attn_3_func)
    _FLASH_ATTN_3_SUPPORTS_RETURN_PROBS = "return_attn_probs" in _flash_attn_3_sig.parameters
else:
    flash_attn_3_func = None
    flash_attn_3_varlen_func = None
    _FLASH_ATTN_3_SUPPORTS_RETURN_PROBS = False


class AttentionBackend(str, Enum):
    """Supported attention backends."""

    # Flash Attention
    FLASH = "flash"
    FLASH_VARLEN = "flash_varlen"
    FLASH_3 = "_flash_3"
    FLASH_VARLEN_3 = "_flash_varlen_3"
    # MPS Flash Attention (Apple Silicon)
    MPS_FLASH = "mps_flash"
    # PyTorch Native Backends
    NATIVE = "native"
    NATIVE_FLASH = "_native_flash"
    NATIVE_MATH = "_native_math"

    @classmethod
    def print_available_backends(cls):
        available_backends = [backend.value for backend in cls.__members__.values()]
        print(f"Available attention backends list: {available_backends}")


# Registry for attention implementations
_ATTENTION_BACKENDS: Dict[str, Callable] = {}
_ATTENTION_CONSTRAINTS: Dict[str, List[Callable]] = {}


def register_backend(name: str, constraints: Optional[List[Callable]] = None):
    def decorator(func):
        _ATTENTION_BACKENDS[name] = func
        _ATTENTION_CONSTRAINTS[name] = constraints or []
        return func

    return decorator


# --- Checks ---
def _check_device_cuda(query: torch.Tensor, **kwargs) -> None:
    if query.device.type != "cuda":
        raise ValueError("Query must be on a CUDA device.")


def _check_qkv_dtype_bf16_or_fp16(query: torch.Tensor, **kwargs) -> None:
    if query.dtype not in (torch.bfloat16, torch.float16):
        raise ValueError("Query must be either bfloat16 or float16.")


def _check_device_mps(query: torch.Tensor, **kwargs) -> None:
    if query.device.type != "mps":
        raise ValueError("Query must be on MPS device.")


def _process_mask(attn_mask: Optional[torch.Tensor], dtype: torch.dtype):
    if attn_mask is None:
        return None

    if attn_mask.ndim == 2:
        attn_mask = attn_mask[:, None, None, :]

    # Convert bool mask to float additive mask
    if attn_mask.dtype == torch.bool:
        # NOTE: We skip checking for all-True mask (torch.all) to avoid graph breaks in torch.compile
        new_mask = torch.zeros_like(attn_mask, dtype=dtype)
        new_mask.masked_fill_(~attn_mask, float("-inf"))
        return new_mask

    return attn_mask


def _normalize_attn_mask(attn_mask: torch.Tensor, batch_size: int, seq_len_k: int) -> torch.Tensor:
    """Normalize an attention mask to shape [batch_size, seq_len_k] (bool)."""
    if attn_mask.dtype != torch.bool:
        # Varlen flash attention needs a boolean "valid token" mask. Convert an additive
        # float mask (0 = keep, -inf = masked) to bool, then normalize its shape below.
        if torch.is_floating_point(attn_mask):
            attn_mask = attn_mask > -1
        # raise ValueError(f"Attention mask must be of type bool, got {attn_mask.dtype}.")

    if attn_mask.ndim == 1:
        attn_mask = attn_mask.unsqueeze(0).expand(batch_size, seq_len_k)
    elif attn_mask.ndim == 2:
        if attn_mask.size(0) not in [1, batch_size]:
            raise ValueError(
                f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for a 2D attention mask."
            )
        attn_mask = attn_mask.expand(batch_size, seq_len_k)
    elif attn_mask.ndim == 3:
        attn_mask = attn_mask.any(dim=1)
        attn_mask = attn_mask.expand(batch_size, seq_len_k)
    elif attn_mask.ndim == 4:
        attn_mask = attn_mask.expand(batch_size, -1, -1, seq_len_k)
        attn_mask = attn_mask.any(dim=(1, 2))

    if attn_mask.shape != (batch_size, seq_len_k):
        # Fallback reshape
        return attn_mask.view(batch_size, seq_len_k)

    return attn_mask


@functools.lru_cache(maxsize=128)
def _prepare_for_flash_attn_varlen_without_mask(
    batch_size: int,
    seq_len_q: int,
    seq_len_kv: int,
    device: Optional[torch.device] = None,
):
    # Optimized to avoid Inductor "pointless_cumsum_replacement" crash and remove graph breaks
    seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device)
    seqlens_k = torch.full((batch_size,), seq_len_kv, dtype=torch.int32, device=device)

    cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=device) * seq_len_q
    cu_seqlens_k = torch.arange(batch_size + 1, dtype=torch.int32, device=device) * seq_len_kv

    return (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (seq_len_q, seq_len_kv)


def _prepare_for_flash_attn_varlen_with_mask(
    batch_size: int,
    seq_len_q: int,
    attn_mask: torch.Tensor,
    device: Optional[torch.device] = None,
):
    seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device)
    # Number of valid (True) key tokens per sample
    seqlens_k = attn_mask.sum(dim=1, dtype=torch.int32)
    # Queries are never padded, so cu_seqlens_q is a fixed stride; arange also avoids an Inductor crash
    cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=device) * seq_len_q

    cu_seqlens_k = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
    cu_seqlens_k[1:] = torch.cumsum(seqlens_k, dim=0)

    max_seqlen_q = seq_len_q
    max_seqlen_k = attn_mask.shape[1]  # not max().item(), static shape to avoid graph break

    return (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)
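

# --- Illustrative sketch (hypothetical helper, not used by the backends) ---
# Varlen kernels index a packed [total_tokens, H, D] tensor through cumulative
# sequence lengths. cu_seqlens[i] is where sample i starts, and cu_seqlens[-1] is
# the total token count. The helper above builds cu_seqlens_k as a leading zero
# followed by cumsum(seqlens_k). When every sample has the same length L, this
# reduces to arange(B + 1) * L, which is how cu_seqlens_q is built.
# Assumes a bool key mask of shape [B, S_kv].
import torch


def _example_cu_seqlens_from_mask(key_mask: torch.Tensor) -> torch.Tensor:
    seqlens = key_mask.sum(dim=1, dtype=torch.int32)  # valid keys per sample
    cu_seqlens = torch.zeros(key_mask.size(0) + 1, dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)  # cast back to int32 on assignment
    return cu_seqlens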


def _prepare_for_flash_attn_varlen(
    batch_size: int,
    seq_len_q: int,
    seq_len_kv: int,
    attn_mask: Optional[torch.Tensor] = None,
    device: Optional[torch.device] = None,
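) -> tuple:
    """Return ((seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k))."""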
    if attn_mask is None:
        return _prepare_for_flash_attn_varlen_without_mask(batch_size, seq_len_q, seq_len_kv, device)
    return _prepare_for_flash_attn_varlen_with_mask(batch_size, seq_len_q, attn_mask, device)


@register_backend(AttentionBackend.FLASH, constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16])
def _flash_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
) -> torch.Tensor:
    if not _CAN_USE_FLASH_ATTN_2:
        raise RuntimeError(
            f"Flash Attention backend '{AttentionBackend.FLASH}' is not usable because of missing package."
        )

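    # flash_attn_func takes no mask argument, so attn_mask is ignored here;
    # use the FLASH_VARLEN backend for padded batches.
    out = flash_attn_func(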
        q=query,
        k=key,
        v=value,
        dropout_p=dropout_p,
        softmax_scale=scale,
        causal=is_causal,
    )
    return out


@register_backend(AttentionBackend.FLASH_VARLEN, constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16])
def _flash_varlen_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
) -> torch.Tensor:
    if not _CAN_USE_FLASH_ATTN_2:
        raise RuntimeError(f"Backend '{AttentionBackend.FLASH_VARLEN}' requires flash-attn.")

    batch_size, seq_len_q, _, _ = query.shape
    _, seq_len_kv, _, _ = key.shape

    if attn_mask is not None:
        attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)

    (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = _prepare_for_flash_attn_varlen(
        batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
    )

    query_packed = query.flatten(0, 1)

    if attn_mask is not None:
        # Pack only the valid key/value tokens: [B, S_kv, H, D] -> [total_valid, H, D].
        # Boolean indexing keeps tokens in (batch, position) order, which matches cu_seqlens_k,
        # and unlike prefix slicing it also handles masks whose valid tokens are not contiguous.
        key_packed = key[attn_mask]
        value_packed = value[attn_mask]
    else:
        key_packed = key.flatten(0, 1)
        value_packed = value.flatten(0, 1)

    out = flash_attn_varlen_func(
        q=query_packed,
        k=key_packed,
        v=value_packed,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        dropout_p=dropout_p,
        softmax_scale=scale,
        causal=is_causal,
    )
    out = out.unflatten(0, (batch_size, -1))
    return out
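

# --- Illustrative sketch (hypothetical reference, not a registered backend) ---
# Shows what the varlen path computes, written in plain PyTorch so it runs on CPU.
# Each sample attends only to its first seqlens_k[b] keys (a prefix key-padding
# mask is assumed), using the same [B, S, H, D] layout as the backends above.
# Can serve as a slow correctness check against flash_attn_varlen_func on a GPU.
from typing import Optional

import torch
import torch.nn.functional as F


def _example_varlen_reference(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    seqlens_k: torch.Tensor,
    scale: Optional[float] = None,
) -> torch.Tensor:
    outputs = []
    for b, valid_len in enumerate(seqlens_k.tolist()):
        # [1, S, H, D] -> [1, H, S, D] for SDPA, keeping only the valid keys
        q = query[b : b + 1].transpose(1, 2)
        k = key[b : b + 1, :valid_len].transpose(1, 2)
        v = value[b : b + 1, :valid_len].transpose(1, 2)
        out = F.scaled_dot_product_attention(q, k, v, scale=scale)
        outputs.append(out.transpose(1, 2))
    return torch.cat(outputs, dim=0)  # [B, S_q, H, D]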


@register_backend(AttentionBackend.FLASH_3, constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16])
def _flash_attention_3(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,  # Ignored: the dense FA3 call below takes no mask
    dropout_p: float = 0.0,  # Not forwarded to the FA3 kernel
    is_causal: bool = False,
    scale: Optional[float] = None,
) -> torch.Tensor:
    if not _CAN_USE_FLASH_ATTN_3:
        raise RuntimeError(f"Backend '{AttentionBackend.FLASH_3}' requires Flash Attention 3 beta.")

    kwargs = {
        "q": query,
        "k": key,
        "v": value,
        "softmax_scale": scale,
        "causal": is_causal,
    }

    if _FLASH_ATTN_3_SUPPORTS_RETURN_PROBS:
        kwargs["return_attn_probs"] = False

    out = flash_attn_3_func(**kwargs)

    if isinstance(out, tuple):
        out = out[0]

    return out


@register_backend(AttentionBackend.FLASH_VARLEN_3, constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16])
def _flash_varlen_attention_3(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
) -> torch.Tensor:
    if not _CAN_USE_FLASH_ATTN_3:
        raise RuntimeError(f"Backend '{AttentionBackend.FLASH_VARLEN_3}' requires Flash Attention 3 beta.")

    batch_size, seq_len_q, _, _ = query.shape
    _, seq_len_kv, _, _ = key.shape

    if attn_mask is not None:
        attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)

    (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = _prepare_for_flash_attn_varlen(
        batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
    )

    query_packed = query.flatten(0, 1)

    if attn_mask is not None:
        # Pack only the valid key/value tokens: [B, S_kv, H, D] -> [total_valid, H, D].
        # Boolean indexing keeps tokens in (batch, position) order, which matches cu_seqlens_k,
        # and unlike prefix slicing it also handles masks whose valid tokens are not contiguous.
        key_packed = key[attn_mask]
        value_packed = value[attn_mask]
    else:
        key_packed = key.flatten(0, 1)
        value_packed = value.flatten(0, 1)

    kwargs = {
        "q": query_packed,
        "k": key_packed,
        "v": value_packed,
        "cu_seqlens_q": cu_seqlens_q,
        "cu_seqlens_k": cu_seqlens_k,
        "max_seqlen_q": max_seqlen_q,
        "max_seqlen_k": max_seqlen_k,
        "softmax_scale": scale,
        "causal": is_causal,
    }

    supports_return_probs = "return_attn_probs" in inspect.signature(flash_attn_3_varlen_func).parameters

    if supports_return_probs:
        kwargs["return_attn_probs"] = False

    out = flash_attn_3_varlen_func(**kwargs)

    if isinstance(out, tuple):
        out = out[0]

    out = out.unflatten(0, (batch_size, -1))
    return out


@register_backend(AttentionBackend.MPS_FLASH, constraints=[_check_device_mps, _check_qkv_dtype_bf16_or_fp16])
def _mps_flash_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
) -> torch.Tensor:
    """MPS Flash Attention for Apple Silicon (M1/M2/M3/M4)."""
    if not _CAN_USE_MPS_FLASH:
        raise RuntimeError(
            f"MPS Flash Attention backend '{AttentionBackend.MPS_FLASH}' requires mps-flash-attn package. "
            "Install with: pip install mps-flash-attn"
        )

    # Convert from (B, S, H, D) to (B, H, S, D) for mps-flash-attn
    query = query.transpose(1, 2)
    key = key.transpose(1, 2)
    value = value.transpose(1, 2)

    # Convert mask to MFA format (bool, True = masked)
    mfa_mask = None
    if attn_mask is not None:
        mfa_mask = mps_flash_attn.convert_mask(_process_mask(attn_mask, query.dtype))

    out = mps_flash_attn.flash_attention(
        query, key, value,
        is_causal=is_causal,
        scale=scale,
        attn_mask=mfa_mask,
    )

    # Convert back to (B, S, H, D)
    return out.transpose(1, 2).contiguous()


def _native_attention_wrapper(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
    backend_kernel=None,
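) -> torch.Tensor:
    """Run torch SDPA on [B, S, H, D] inputs, optionally restricted to one SDPA kernel."""
    # SDPA expects [B, H, S, D]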
    query = query.transpose(1, 2)
    key = key.transpose(1, 2)
    value = value.transpose(1, 2)
    attn_mask = _process_mask(attn_mask, query.dtype)

    if backend_kernel is not None:
        with torch.nn.attention.sdpa_kernel(backend_kernel):
            out = F.scaled_dot_product_attention(
                query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale
            )
    else:
        out = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale
        )

    return out.transpose(1, 2).contiguous()


@register_backend(AttentionBackend.NATIVE_FLASH)
def _native_flash_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
) -> torch.Tensor:
    # PyTorch's SDPA flash kernel does not support an explicit attn_mask, so the mask is dropped.
    return _native_attention_wrapper(
        query,
        key,
        value,
        attn_mask=None,
        dropout_p=dropout_p,
        is_causal=is_causal,
        scale=scale,
        backend_kernel=torch.nn.attention.SDPBackend.FLASH_ATTENTION,
    )


@register_backend(AttentionBackend.NATIVE_MATH)
def _math_attention(*args, **kwargs):
    return _native_attention_wrapper(*args, **kwargs, backend_kernel=torch.nn.attention.SDPBackend.MATH)


@register_backend(AttentionBackend.NATIVE)
def _native_attention(*args, **kwargs):
    return _native_attention_wrapper(*args, **kwargs, backend_kernel=None)


def dispatch_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
    backend: Union[str, AttentionBackend, None] = None,
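) -> torch.Tensor:
    """Run attention on [B, S, H, D] tensors with the requested backend.

    `backend` may be an AttentionBackend member, its string value, or None (native SDPA).
    Unrecognized names fall back to native SDPA.
    """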
    if isinstance(backend, AttentionBackend):
        backend = backend.value
    elif backend is None:
        backend = AttentionBackend.NATIVE
    else:
        backend = str(backend)

    # Explicit dispatch to avoid dynamo guard issues on global dict
    if backend == AttentionBackend.FLASH:
        return _flash_attention(query, key, value, attn_mask, dropout_p, is_causal, scale)
    elif backend == AttentionBackend.FLASH_VARLEN:
        return _flash_varlen_attention(query, key, value, attn_mask, dropout_p, is_causal, scale)
    elif backend == AttentionBackend.FLASH_3:
        return _flash_attention_3(query, key, value, attn_mask, dropout_p, is_causal, scale)
    elif backend == AttentionBackend.FLASH_VARLEN_3:
        return _flash_varlen_attention_3(query, key, value, attn_mask, dropout_p, is_causal, scale)
    elif backend == AttentionBackend.MPS_FLASH:
        return _mps_flash_attention(query, key, value, attn_mask, dropout_p, is_causal, scale)
    elif backend == AttentionBackend.NATIVE_FLASH:
        return _native_flash_attention(query, key, value, attn_mask, dropout_p, is_causal, scale)
    elif backend == AttentionBackend.NATIVE_MATH:
        return _math_attention(query, key, value, attn_mask, dropout_p, is_causal, scale)
    else:
        return _native_attention(query, key, value, attn_mask, dropout_p, is_causal, scale)


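def set_attention_backend(backend: Union[str, AttentionBackend, None]):
    """Set the backend used by ZImageAttention (no-op if zimage.transformer cannot be imported)."""
    try:
        from zimage.transformer import ZImageAttention

        if isinstance(backend, AttentionBackend):
            # Store the plain value, matching how dispatch_attention normalizes enum members
            backend = backend.value
        elif backend is not None:
            backend = str(backend)
        ZImageAttention._attention_backend = backend
    except ImportError:
        pass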


================================================
FILE: src/utils/helpers.py
================================================
"""Helper utilities for Z-Image."""

import hashlib
from pathlib import Path
from typing import Optional, List, Tuple, Dict

from loguru import logger
import torch

from config import BYTES_PER_GB


def format_bytes(size: float) -> str:
    """
    Format bytes to GB string.

    Args:
        size: Size in bytes

    Returns:
        Formatted string in GB
    """
    n = size / BYTES_PER_GB
    return f"{n:.2f} GB"


def print_memory_stats(stage: str) -> None:
    """
    Print CUDA memory statistics.

    Args:
        stage: Description of current stage
    """
    if not torch.cuda.is_available():
        logger.warning("CUDA not available, skipping memory stats")
        return

    torch.cuda.synchronize()
    peak_allocated = torch.cuda.max_memory_allocated()
    peak_reserved = torch.cuda.max_memory_reserved()
    current_allocated = torch.cuda.memory_allocated()
    current_reserved = torch.cuda.memory_reserved()

    logger.info(f"[{stage}] Memory Stats:")
    logger.info(f"  Current Allocated: {format_bytes(current_allocated)}")
    logger.info(f"  Current Reserved:  {format_bytes(current_reserved)}")
    logger.info(f"  Peak Allocated:    {format_bytes(peak_allocated)}")
    logger.info(f"  Peak Reserved:     {format_bytes(peak_reserved)}")


def compute_file_md5(file_path: Path, chunk_size: int = 8192) -> str:
    """Compute MD5 hash of a file."""
    md5_hash = hashlib.md5()
    with open(file_path, "rb") as f:
        while chunk := f.read(chunk_size):
            md5_hash.update(chunk)
    return md5_hash.hexdigest()
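

# --- Illustrative sketch (hypothetical helper, not used by the loader) ---
# compute_file_md5 hashes a file in fixed-size chunks, so multi-GB weight shards
# never have to fit in memory. The digest is identical to hashing the whole byte
# string at once. This standalone variant uses iter(callable, sentinel) instead
# of the walrus loop.
import hashlib
from pathlib import Path


def _example_chunked_md5(file_path: Path, chunk_size: int = 8192) -> str:
    md5_hash = hashlib.md5()
    with open(file_path, "rb") as f:
        # read() returns b"" at EOF, which stops the iterator
        for chunk in iter(lambda: f.read(chunk_size), b""):
            md5_hash.update(chunk)
    return md5_hash.hexdigest()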


def load_manifest(manifest_file: Path) -> Dict[str, Optional[str]]:
    """Load manifest file. Returns dict mapping file paths to MD5 hashes (or None)."""
    manifest = {}
    if not manifest_file.exists():
        return manifest
    
    with open(manifest_file, "r", encoding="utf-8") as f:
        for line_num, line in enumerate(f, 1):
            line = line.strip()
            # Skip empty lines and comments
            if not line or line.startswith("#"):
                continue
            
            parts = line.split()
            
            if len(parts) == 1:
                # Only file path, no checksum
                file_path = parts[0]
                manifest[file_path] = None
            elif len(parts) == 2:
                # File path with checksum
                if len(parts[0]) == 32 and all(c in '0123456789abcdef' for c in parts[0].lower()):
                    md5_hash, file_path = parts
                else:
                    file_path, md5_hash = parts
                manifest[file_path] = md5_hash
            else:
                logger.warning(f"Invalid manifest format at line {line_num}: {line}")
                continue
    
    return manifest
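

# --- Illustrative sketch (hypothetical helper, not used by the loader) ---
# load_manifest accepts three line shapes: "<path>", "<md5> <path>" (md5sum-style)
# and "<path> <md5>". A token counts as the hash when it is exactly 32 hex
# characters. This sketch classifies a single line the same way. Like
# load_manifest, it splits on whitespace, so paths containing spaces are not supported.
from typing import Optional, Tuple

_HEX_DIGITS = set("0123456789abcdef")


def _example_parse_manifest_line(line: str) -> Optional[Tuple[str, Optional[str]]]:
    line = line.strip()
    if not line or line.startswith("#"):
        return None  # blank line or comment
    parts = line.split()
    if len(parts) == 1:
        return parts[0], None  # path only: existence check, no checksum
    if len(parts) == 2:
        first, second = parts
        if len(first) == 32 and set(first.lower()) <= _HEX_DIGITS:
            return second, first  # "<md5> <path>"
        return first, second  # "<path> <md5>"
    # load_manifest logs a warning and skips lines with more than two fields
    return None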


def verify_file_integrity(
    base_dir: Path, 
    manifest: Dict[str, Optional[str]],
    verify_checksums: bool = True
) -> Tuple[bool, List[str], List[str]]:
    """
    Verify file integrity using a manifest.
    
    Args:
        base_dir: Base directory for relative file paths
        manifest: Dictionary of relative paths to MD5 hashes (None if no hash provided)
        verify_checksums: If True, verify MD5 checksums when available; if False, only check existence
        
    Returns:
        Tuple of (all_valid: bool, missing_files: List[str], corrupted_files: List[str])
    """
    missing = []
    corrupted = []
    
    for rel_path, expected_md5 in manifest.items():
        file_path = base_dir / rel_path
        
        if not file_path.exists():
            missing.append(rel_path)
            continue
        
        # Only verify checksum if requested AND hash is available
        if verify_checksums and expected_md5 is not None:
            try:
                actual_md5 = compute_file_md5(file_path)
                if actual_md5 != expected_md5:
                    corrupted.append(rel_path)
                    logger.debug(f"Checksum mismatch for {rel_path}: expected {expected_md5}, got {actual_md5}")
            except Exception as e:
                logger.error(f"Failed to compute checksum for {rel_path}: {e}")
                corrupted.append(rel_path)
    
    all_valid = len(missing) == 0 and len(corrupted) == 0
    return all_valid, missing, corrupted


def ensure_model_weights(
    model_path: str, 
    repo_id: str = "Tongyi-MAI/Z-Image-Turbo",
    verify: bool = False,
    manifest_name: Optional[str] = None
) -> Path:
    """
    Ensure model weights exist and optionally verify integrity.
    
    Args:
        model_path: Path to model directory
        repo_id: HuggingFace repo ID for download
        verify: If True, verify MD5 checksums; if False, only check existence
        manifest_name: Manifest file name in src/config/manifests/ (auto-detect if None)
        
    Returns:
        Path to validated model directory
    """
    from huggingface_hub import snapshot_download
    
    target_dir = Path(model_path)
    
    # Determine manifest path
    if manifest_name:
        # Explicitly specified manifest from config/manifests/
        manifest_path = Path(__file__).parent.parent / "config" / "manifests" / manifest_name
    else:
        # Auto-detect
        model_name = target_dir.name.lower()  # e.g., "Z-Image-Turbo" -> "z-image-turbo"
        config_manifest = Path(__file__).parent.parent / "config" / "manifests" / f"{model_name}.txt"
        
        if config_manifest.exists():
            manifest_path = config_manifest
        else:
            # Fallback
            manifest_path = target_dir / "manifest.txt"
    
    manifest = load_manifest(manifest_path)
    
    if not manifest:
        logger.warning(f"Manifest not found or empty: {manifest_path}")
        logger.warning("Skipping file verification (assuming model exists)")
        if target_dir.exists():
            logger.info(f"✓ Model directory exists: {target_dir}")
            return target_dir
        else:
            logger.warning(f"Model directory not found: {target_dir}")
            missing_files = ["entire model directory"]
            corrupted_files = []
    else:
        # Count files with checksums
        files_with_checksums = sum(1 for v in manifest.values() if v is not None)
        
        if verify and files_with_checksums == 0:
            logger.info("Verification requested, but the manifest has no checksums; only checking file existence")
        elif verify and files_with_checksums > 0:
            logger.info(f"Verifying {files_with_checksums} file(s) with MD5 checksums...")
        
        # Verify files
        all_valid, missing_files, corrupted_files = verify_file_integrity(
            target_dir, manifest, verify_checksums=verify
        )
        
        if all_valid:
            if verify and files_with_checksums > 0:
                logger.success(f"✓ All files verified with MD5 checksums in {target_dir}")
            else:
                logger.info(f"✓ All {len(manifest)} required files exist in {target_dir}")
            return target_dir
    
    # Report missing and corrupted files
    if missing_files:
        logger.warning(f"Missing {len(missing_files)} file(s):")
        for f in missing_files[:10]:
            logger.warning(f"  - {f}")
        if len(missing_files) > 10:
            logger.warning(f"  ... and {len(missing_files) - 10} more")
    
    if corrupted_files:
        logger.error(f"Corrupted {len(corrupted_files)} file(s) (checksum mismatch):")
        for f in corrupted_files[:10]:
            logger.error(f"  - {f}")
        if len(corrupted_files) > 10:
            logger.error(f"  ... and {len(corrupted_files) - 10} more")
    
    # Download model weights
    logger.info(f"\nAttempting to download from {repo_id}...")
    try:
        target_dir.mkdir(parents=True, exist_ok=True)
        snapshot_download(
            repo_id=repo_id,
            local_dir=str(target_dir),
            local_dir_use_symlinks=False,
            resume_download=True,
        )
        logger.success("✓ Download completed")
    except Exception as e:
        logger.error(f"✗ Download failed: {e}")
        logger.info(
            f"\nIf you are offline, please manually download from:\n"
            f"  https://huggingface.co/{repo_id}\n"
            f"and place in: {target_dir.absolute()}"
        )
        raise RuntimeError(f"Failed to download model weights: {e}") from e
    
    # Verify after download
    if manifest:
        all_valid, missing_after, corrupted_after = verify_file_integrity(
            target_dir, manifest, verify_checksums=verify
        )
        
        if not all_valid:
            error_msg = []
            if missing_after:
                error_msg.append(f"Still missing {len(missing_after)} file(s)")
            if corrupted_after:
                error_msg.append(f"Still corrupted {len(corrupted_after)} file(s)")
            
            raise FileNotFoundError(
                f"After download: {', '.join(error_msg)}\n"
                f"Please verify the download or manually place files in:\n"
                f"  {target_dir.absolute()}"
            )
    
    logger.success("✓ All model weights validated successfully")
    return target_dir


================================================
FILE: src/utils/import_utils.py
================================================
import importlib.util

import torch


def is_flash_attn_available():
    return importlib.util.find_spec("flash_attn") is not None


def is_flash_attn_3_available():
    return importlib.util.find_spec("flash_attn_interface") is not None


def is_torch_version(operator: str, version: str) -> bool:
    """Return whether the installed torch version satisfies `torch <operator> version`.

    `operator` must be one of ">", ">=", "==", "<=", "<".
    """
    from packaging import version as pversion

    # Compare base versions so local/dev suffixes (e.g. "2.5.1+cu121") do not skew the result
    torch_version = pversion.parse(pversion.parse(torch.__version__).base_version)
    target_version = pversion.parse(version)

    if operator == ">":
        return torch_version > target_version
    elif operator == ">=":
        return torch_version >= target_version
    elif operator == "==":
        return torch_version == target_version
    elif operator == "<=":
        return torch_version <= target_version
    elif operator == "<":
        return torch_version < target_version
    raise ValueError(f"Unsupported comparison operator: {operator!r}")


================================================
FILE: src/utils/loader.py
================================================
"""Model loading utilities for Z-Image components."""

import json
import os
from pathlib import Path
import sys
from typing import Optional, Union

from loguru import logger
from safetensors.torch import load_file
import torch
from transformers import AutoModel, AutoTokenizer

from config import (
    DEFAULT_SCHEDULER_NUM_TRAIN_TIMESTEPS,
    DEFAULT_SCHEDULER_SHIFT,
    DEFAULT_SCHEDULER_USE_DYNAMIC_SHIFTING,
    DEFAULT_TRANSFORMER_CAP_FEAT_DIM,
    DEFAULT_TRANSFORMER_DIM,
    DEFAULT_TRANSFORMER_F_PATCH_SIZE,
    DEFAULT_TRANSFORMER_IN_CHANNELS,
    DEFAULT_TRANSFORMER_N_HEADS,
    DEFAULT_TRANSFORMER_N_KV_HEADS,
    DEFAULT_TRANSFORMER_N_LAYERS,
    DEFAULT_TRANSFORMER_N_REFINER_LAYERS,
    DEFAULT_TRANSFORMER_NORM_EPS,
    DEFAULT_TRANSFORMER_PATCH_SIZE,
    DEFAULT_TRANSFORMER_QK_NORM,
    DEFAULT_TRANSFORMER_T_SCALE,
    DEFAULT_VAE_IN_CHANNELS,
    DEFAULT_VAE_LATENT_CHANNELS,
    DEFAULT_VAE_NORM_NUM_GROUPS,
    DEFAULT_VAE_OUT_CHANNELS,
    DEFAULT_VAE_SCALING_FACTOR,
    ROPE_AXES_DIMS,
    ROPE_AXES_LENS,
    ROPE_THETA,
)
from zimage.autoencoder import AutoencoderKL as LocalAutoencoderKL
from zimage.scheduler import FlowMatchEulerDiscreteScheduler

DIFFUSERS_AVAILABLE = False


def load_config(config_path: str) -> dict:
    with open(config_path, "r", encoding="utf-8") as f:
        return json.load(f)


def load_sharded_safetensors(weight_dir: Path, device: str = "cuda", dtype: Optional[torch.dtype] = None) -> dict:
    """Load sharded safetensors from a directory."""
    weight_dir = Path(weight_dir)
    index_files = list(weight_dir.glob("*.safetensors.index.json"))

    state_dict = {}
    if index_files:
        # Load sharded weights
        with open(index_files[0], "r") as f:
            index = json.load(f)
        weight_map = index.get("weight_map", {})
        shard_files = sorted(set(weight_map.values()))  # each shard once, in a deterministic order
        for shard_file in shard_files:
            shard_path = weight_dir / shard_file
            shard_state = load_file(str(shard_path), device=str(device))
            state_dict.update(shard_state)
    else:
        # Load single safetensors file
        safetensors_files = list(weight_dir.glob("*.safetensors"))
        if not safetensors_files:
            raise FileNotFoundError(f"No safetensors files found in {weight_dir}")
        state_dict = load_file(str(safetensors_files[0]), device=str(device))

    # Cast to target dtype if specified
    if dtype is not None:
        state_dict = {k: v.to(dtype) if v.dtype != dtype else v for k, v in state_dict.items()}

    return state_dict
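

# --- Illustrative sketch (hypothetical helper, not used by the loader) ---
# A sharded checkpoint ships a "*.safetensors.index.json" whose "weight_map" maps
# each tensor name to the shard file that stores it. load_sharded_safetensors loads
# every distinct shard once. This sketch groups tensor names by shard in sorted
# order, which makes it easy to see which tensors a missing shard would drop.
import json
from collections import defaultdict
from typing import Dict, List


def _example_group_weight_map(index_json: str) -> Dict[str, List[str]]:
    weight_map = json.loads(index_json).get("weight_map", {})
    by_shard: Dict[str, List[str]] = defaultdict(list)
    for tensor_name, shard_file in weight_map.items():
        by_shard[shard_file].append(tensor_name)
    return {shard: sorted(names) for shard, names in sorted(by_shard.items())}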


def load_from_local_dir(
    model_dir: Union[str, Path],
    device: str = "cuda",
    dtype: torch.dtype = torch.bfloat16,
    verbose: bool = False,
    compile: bool = False,
) -> dict:
    """
    Load all Z-Image components from local directory.

    Args:
        model_dir: Path to model directory
        device: Device to load models on
        dtype: Data type for model weights
        verbose: Whether to display loading logs
        compile: Whether to compile transformer and vae with torch.compile

    Returns:
        Dictionary containing transformer, vae, text_encoder, tokenizer, and scheduler
    """
    model_dir = Path(model_dir)

    sys.path.insert(0, str(model_dir.parent.parent / "Z-Image" / "src"))
    from zimage.transformer import ZImageTransformer2DModel

    if verbose:
        logger.info(f"Loading Z-Image from: {model_dir}")

    # DiT
    if verbose:
        logger.info("Loading DiT...")
    transformer_dir = model_dir / "transformer"
    config = load_config(str(transformer_dir / "config.json"))

    with torch.device("meta"):
        transformer = ZImageTransformer2DModel(
            all_patch_size=tuple(config.get("all_patch_size", DEFAULT_TRANSFORMER_PATCH_SIZE)),
            all_f_patch_size=tuple(config.get("all_f_patch_size", DEFAULT_TRANSFORMER_F_PATCH_SIZE)),
            in_channels=config.get("in_channels", DEFAULT_TRANSFORMER_IN_CHANNELS),
            dim=config.get("dim", DEFAULT_TRANSFORMER_DIM),
            n_layers=config.get("n_layers", DEFAULT_TRANSFORMER_N_LAYERS),
            n_refiner_layers=config.get("n_refiner_layers", DEFAULT_TRANSFORMER_N_REFINER_LAYERS),
            n_heads=config.get("n_heads", DEFAULT_TRANSFORMER_N_HEADS),
            n_kv_heads=config.get("n_kv_heads", DEFAULT_TRANSFORMER_N_KV_HEADS),
            norm_eps=config.get("norm_eps", DEFAULT_TRANSFORMER_NORM_EPS),
            qk_norm=config.get("qk_norm", DEFAULT_TRANSFORMER_QK_NORM),
            cap_feat_dim=config.get("cap_feat_dim", DEFAULT_TRANSFORMER_CAP_FEAT_DIM),
            rope_theta=config.get("rope_theta", ROPE_THETA),
            t_scale=config.get("t_scale", DEFAULT_TRANSFORMER_T_SCALE),
            axes_dims=config.get("axes_dims", ROPE_AXES_DIMS),
            axes_lens=config.get("axes_lens", ROPE_AXES_LENS),
        ).to(dtype)

    # DiT weights: load on CPU first, then move the model to the target device to limit peak GPU memory
    state_dict = load_sharded_safetensors(transformer_dir, device="cpu", dtype=dtype)
    transformer.load_state_dict(state_dict, strict=False, assign=True)
    del state_dict

    if verbose:
        logger.info("Moving DiT to GPU...")
    transformer = transformer.to(device)
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    transformer.eval()

    # VAE
    if verbose:
        logger.info("Loading VAE...")
    vae_dir = model_dir / "vae"
    vae_config = load_config(str(vae_dir / "config.json"))

    vae = LocalAutoencoderKL(
        in_channels=vae_config.get("in_channels", DEFAULT_VAE_IN_CHANNELS),
        out_channels=vae_config.get("out_channels", DEFAULT_VAE_OUT_CHANNELS),
        down_block_types=tuple(vae_config.get("down_block_types", ("DownEncoderBlock2D",))),
        up_block_types=tuple(vae_config.get("up_block_types", ("UpDecoderBlock2D",))),
        block_out_channels=tuple(vae_config.get("block_out_channels", (64,))),
        layers_per_block=vae_config.get("layers_per_block", 1),
        latent_channels=vae_config.get("latent_channels", DEFAULT_VAE_LATENT_CHANNELS),
        norm_num_groups=vae_config.get("norm_num_groups", DEFAULT_VAE_NORM_NUM_GROUPS),
        scaling_factor=vae_config.get("scaling_factor", DEFAULT_VAE_SCALING_FACTOR),
        shift_factor=vae_config.get("shift_factor", None),
        use_quant_conv=vae_config.get("use_quant_conv", True),
        use_post_quant_conv=vae_config.get("use_post_quant_conv", True),
        mid_block_add_attention=vae_config.get("mid_block_add_attention", True),
    )

    # VAE (fp32 for better precision)
    vae_state_dict = load_sharded_safetensors(vae_dir, device="cpu")
    vae.load_state_dict(vae_state_dict, strict=False)
    del vae_state_dict
    vae.to(device=device, dtype=torch.float32)
    vae.eval()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Text Encoder
    if verbose:
        logger.info("Loading Text Encoder...")
    text_encoder_dir = model_dir / "text_encoder"
    text_encoder = AutoModel.from_pretrained(
        str(text_encoder_dir),
        # Older transformers releases take torch_dtype=dtype instead of dtype=dtype
        dtype=dtype,
        trust_remote_code=True,
    )
    text_encoder.to(device)
    text_encoder.eval()

    # Tokenizer
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    if verbose:
        logger.info("Loading Tokenizer...")
    tokenizer_dir = model_dir / "tokenizer"
    tokenizer = AutoTokenizer.from_pretrained(
        str(tokenizer_dir) if tokenizer_dir.exists() else str(text_encoder_dir),
        trust_remote_code=True,
    )

    # Scheduler
    if verbose:
        logger.info("Loading Scheduler...")
    scheduler_dir = model_dir / "scheduler"
    scheduler_config = load_config(str(scheduler_dir / "scheduler_config.json"))
    scheduler = FlowMatchEulerDiscreteScheduler(
        num_train_timesteps=scheduler_config.get("num_train_timesteps", DEFAULT_SCHEDULER_NUM_TRAIN_TIMESTEPS),
        shift=scheduler_config.get("shift", DEFAULT_SCHEDULER_SHIFT),
        use_dynamic_shifting=scheduler_config.get("use_dynamic_shifting", DEFAULT_SCHEDULER_USE_DYNAMIC_SHIFTING),
    )

    if compile:
        if verbose:
            logger.info("Compiling DiT and VAE...")
        transformer = torch.compile(transformer)
        vae = torch.compile(vae)

    if verbose:
        logger.success("All components loaded successfully")

    return {
        "transformer": transformer,
        "vae": vae,
        "text_encoder": text_encoder,
        "tokenizer": tokenizer,
        "scheduler": scheduler,
    }
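

# --- Illustrative sketch (hypothetical; a toy nn.Linear stands in for the DiT) ---
# load_from_local_dir builds the DiT under torch.device("meta"), so the random
# init allocates no memory. It then calls load_state_dict(..., assign=True), so
# the loaded tensors replace the meta placeholders instead of being copied into
# them. Assumes PyTorch >= 2.1, where assign= was added. Any parameter missing
# from the state dict stays on "meta", and strict=False does not report it.
import torch
import torch.nn as nn


def _example_meta_init_and_assign(dtype: torch.dtype = torch.bfloat16) -> nn.Module:
    with torch.device("meta"):
        model = nn.Linear(4, 2).to(dtype)  # shapes and dtypes only, no storage
    state_dict = {
        "weight": torch.randn(2, 4).to(dtype),
        "bias": torch.zeros(2, dtype=dtype),
    }
    result = model.load_state_dict(state_dict, strict=False, assign=True)
    # With strict=False, missing keys must be checked explicitly
    assert not result.missing_keys, f"still on meta: {result.missing_keys}"
    return model.eval()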


================================================
FILE: src/zimage/__init__.py
================================================
"""Z-Image PyTorch Native Implementation."""

from .pipeline import generate
from .transformer import ZImageTransformer2DModel

__all__ = [
    "ZImageTransformer2DModel",
    "generate",
]


================================================
FILE: src/zimage/autoencoder.py
================================================
"""AutoencoderKL implementation compatible with diffusers weights."""

# Modified from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/autoencoder.py
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
import torch.nn as nn


@dataclass
class AutoencoderKLOutput:
    sample: torch.Tensor


class AutoencoderConfig:
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    def get(self, key, default=None):
        return self.__dict__.get(key, default)

    def __getattr__(self, name):
        return self.__dict__.get(name)


def swish(x):
    return x * torch.sigmoid(x)


class ResnetBlock2D(nn.Module):
    def __init__(self, in_channels, out_channels=None, dropout=0.0, temb_channels=512, groups=32, eps=1e-6):
        super().__init__()
        out_channels = out_channels or in_channels
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        self.nonlinearity = swish

        if self.in_channels != self.out_channels:
            self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        else:
            self.conv_shortcut = None

    def forward(self, input_tensor, temb=None):
        hidden_states = input_tensor
        hidden_states = self.norm1(hidden_states)
        hidden_states = self.nonlinearity(hidden_states)
        hidden_states = self.conv1(hidden_states)

        hidden_states = self.norm2(hidden_states)
        hidden_states = self.nonlinearity(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor)

        output_tensor = (input_tensor + hidden_states) / 1.0
        return output_tensor


class Attention(nn.Module):
    def __init__(self, in_channels, heads=1, dim_head=None, groups=32, eps=1e-6):
        super().__init__()
        self.heads = heads
        self.in_channels = in_channels
        self.group_norm = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)

        self.to_q = nn.Linear(in_channels, in_channels)
        self.to_k = nn.Linear(in_channels, in_channels)
        self.to_v = nn.Linear(in_channels, in_channels)
        self.to_out = nn.ModuleList([nn.Linear(in_channels, in_channels)])

    def forward(self, hidden_states):
        b, c, h, w = hidden_states.shape
        residual = hidden_states
        hidden_states = self.group_norm(hidden_states)
        hidden_states = hidden_states.view(b, c, -1).transpose(1, 2)  # (B, H*W, C)

        query = self.to_q(hidden_states)
        key = self.to_k(hidden_states)
        value = self.to_v(hidden_states)

        # heads == 1, so (B, H*W, C) is passed as (batch, sequence, embedding) directly.
        hidden_states = torch.nn.functional.scaled_dot_product_attention(query, key, value)

        hidden_states = self.to_out[0](hidden_states)
        hidden_states = hidden_states.transpose(1, 2).view(b, c, h, w)

        return residual + hidden_states
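

# Sketch (illustrative): what the mid-block Attention computes for one image,
# in plain Python. Each of the H*W spatial positions is a token of width C;
# with a single head, scaled_dot_product_attention reduces to
# softmax(Q K^T / sqrt(C)) V. The projections (to_q/to_k/to_v/to_out), the
# GroupNorm and the residual are omitted. `_sketch_single_head_attention` is a
# hypothetical helper, not part of the model.
def _sketch_single_head_attention(q, k, v):
    import math

    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for qi in q:
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        m = max(scores)  # subtract the max before exp() for numerical stability
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        weights = [w / total for w in weights]
        # Weighted average of the value rows, one output channel at a time.
        out.append([sum(w * vj[c] for w, vj in zip(weights, v)) for c in range(len(v[0]))])
    return out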


class Downsample2D(nn.Module):
    def __init__(self, channels, with_conv=True, out_channels=None, padding=1):
        super().__init__()
        out_channels = out_channels or channels
        self.with_conv = with_conv
        if with_conv:
            self.conv = nn.Conv2d(channels, out_channels, kernel_size=3, stride=2, padding=padding)

    def forward(self, hidden_states):
        if self.with_conv:
            return self.conv(hidden_states)
        else:
            return torch.nn.functional.avg_pool2d(hidden_states, kernel_size=2, stride=2)


class Upsample2D(nn.Module):
    def __init__(self, channels, with_conv=True, out_channels=None):
        super().__init__()
        out_channels = out_channels or channels
        self.with_conv = with_conv
        if with_conv:
            self.conv = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, hidden_states):
        hidden_states = torch.nn.functional.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
        if self.with_conv:
            hidden_states = self.conv(hidden_states)
        return hidden_states


class DownEncoderBlock2D(nn.Module):
    def __init__(self, in_channels, out_channels, num_layers=1, resnet_eps=1e-6, resnet_groups=32, add_downsample=True):
        super().__init__()
        resnets = []
        for i in range(num_layers):
            in_c = in_channels if i == 0 else out_channels
            resnets.append(ResnetBlock2D(in_c, out_channels, eps=resnet_eps, groups=resnet_groups))
        self.resnets = nn.ModuleList(resnets)

        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [Downsample2D(out_channels, with_conv=True, out_channels=out_channels, padding=0)]
            )
        else:
            self.downsamplers = None

    def forward(self, hidden_states):
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                pad = (0, 1, 0, 1)
                hidden_states = torch.nn.functional.pad(hidden_states, pad, mode="constant", value=0)
                hidden_states = downsampler(hidden_states)

        return hidden_states
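

# Sketch (illustrative): spatial size through the encoder's downsamplers.
# Each Downsample2D pads right/bottom by one pixel, then applies a 3x3 conv
# with stride 2 and no padding, so an even side of length n becomes
# (n + 1 - 3) // 2 + 1 == n // 2. With k entries in block_out_channels there
# are k - 1 downsamplers, i.e. a scale factor of 2 ** (k - 1), which is how the
# pipeline derives vae_scale_factor. Both helpers are hypothetical.
def _sketch_downsample_size(n, kernel=3, stride=2, pad_total=1):
    # Conv output size: floor((n + total_padding - kernel) / stride) + 1
    return (n + pad_total - kernel) // stride + 1


def _sketch_encoder_latent_size(n, num_blocks):
    # The final block has no downsampler, hence num_blocks - 1 halvings.
    for _ in range(num_blocks - 1):
        n = _sketch_downsample_size(n)
    return n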


class UpDecoderBlock2D(nn.Module):
    def __init__(self, in_channels, out_channels, num_layers=1, resnet_eps=1e-6, resnet_groups=32, add_upsample=True):
        super().__init__()
        resnets = []
        for i in range(num_layers):
            in_c = in_channels if i == 0 else out_channels
            resnets.append(ResnetBlock2D(in_c, out_channels, eps=resnet_eps, groups=resnet_groups))
        self.resnets = nn.ModuleList(resnets)

        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, with_conv=True, out_channels=out_channels)])
        else:
            self.upsamplers = None

    def forward(self, hidden_states):
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states)

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states)

        return hidden_states


class UNetMidBlock2D(nn.Module):
    def __init__(self, in_channels, resnet_eps=1e-6, resnet_groups=32, attention_head_dim=None):
        super().__init__()
        self.resnets = nn.ModuleList(
            [
                ResnetBlock2D(in_channels, in_channels, eps=resnet_eps, groups=resnet_groups),
                ResnetBlock2D(in_channels, in_channels, eps=resnet_eps, groups=resnet_groups),
            ]
        )
        self.attentions = nn.ModuleList([Attention(in_channels, heads=1, groups=resnet_groups, eps=resnet_eps)])

    def forward(self, hidden_states):
        hidden_states = self.resnets[0](hidden_states)
        for attn in self.attentions:
            hidden_states = attn(hidden_states)
        hidden_states = self.resnets[1](hidden_states)
        return hidden_states


class Encoder(nn.Module):
    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        block_out_channels=(64,),
        layers_per_block=2,
        norm_num_groups=32,
        double_z=True,
    ):
        super().__init__()
        self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)

        self.down_blocks = nn.ModuleList([])
        output_channel = block_out_channels[0]
        for i, block_out_channel in enumerate(block_out_channels):
            input_channel = output_channel
            output_channel = block_out_channel
            is_final_block = i == len(block_out_channels) - 1

            block = DownEncoderBlock2D(
                input_channel,
                output_channel,
                num_layers=layers_per_block,
                resnet_groups=norm_num_groups,
                add_downsample=not is_final_block,
            )
            self.down_blocks.append(block)

        self.mid_block = UNetMidBlock2D(
            block_out_channels[-1],
            resnet_groups=norm_num_groups,
        )

        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
        self.conv_act = nn.SiLU()

        conv_out_channels = 2 * out_channels if double_z else out_channels
        self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)

    def forward(self, x):
        x = self.conv_in(x)
        for block in self.down_blocks:
            x = block(x)
        x = self.mid_block(x)
        x = self.conv_norm_out(x)
        x = self.conv_act(x)
        x = self.conv_out(x)
        return x


class Decoder(nn.Module):
    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        block_out_channels=(64,),
        layers_per_block=2,
        norm_num_groups=32,
    ):
        super().__init__()
        self.conv_in = nn.Conv2d(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1)

        self.mid_block = UNetMidBlock2D(
            block_out_channels[-1],
            resnet_groups=norm_num_groups,
        )

        self.up_blocks = nn.ModuleList([])
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]

        for i, block_out_channel in enumerate(reversed_block_out_channels):
            input_channel = output_channel
            output_channel = block_out_channel
            is_final_block = i == len(block_out_channels) - 1
            block = UpDecoderBlock2D(
                input_channel,
                output_channel,
                num_layers=layers_per_block + 1,
                resnet_groups=norm_num_groups,
                add_upsample=not is_final_block,
            )
            self.up_blocks.append(block)

        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
        self.conv_act = nn.SiLU()
        self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        x = self.conv_in(x)
        x = self.mid_block(x)
        for block in self.up_blocks:
            x = block(x)
        x = self.conv_norm_out(x)
        x = self.conv_act(x)
        x = self.conv_out(x)
        return x


class AutoencoderKL(nn.Module):
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
        up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: int = 1,
        act_fn: str = "silu",
        latent_channels: int = 4,
        norm_num_groups: int = 32,
        sample_size: int = 32,
        scaling_factor: float = 0.18215,
        shift_factor: Optional[float] = None,
        force_upcast: bool = True,
        use_quant_conv: bool = True,
        use_post_quant_conv: bool = True,
        mid_block_add_attention: bool = True,
        **kwargs,
    ):
        super().__init__()
        self.config = AutoencoderConfig(
            in_channels=in_channels,
            out_channels=out_channels,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            latent_channels=latent_channels,
            scaling_factor=scaling_factor,
            shift_factor=shift_factor,
        )

        self.encoder = Encoder(
            in_channels=in_channels,
            out_channels=latent_channels,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            norm_num_groups=norm_num_groups,
            double_z=True,
        )

        self.decoder = Decoder(
            in_channels=latent_channels,
            out_channels=out_channels,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            norm_num_groups=norm_num_groups,
        )

        self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None
        self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) if use_post_quant_conv else None

    @property
    def dtype(self):
        return next(self.parameters()).dtype

    def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
        if self.post_quant_conv is not None:
            z = self.post_quant_conv(z)

        dec = self.decoder(z)

        if not return_dict:
            return (dec,)

        return AutoencoderKLOutput(sample=dec)


================================================
FILE: src/zimage/pipeline.py
================================================
"""Z-Image Pipeline."""

import inspect
from typing import List, Optional, Union

from loguru import logger
import torch

from config import (
    BASE_IMAGE_SEQ_LEN,
    BASE_SHIFT,
    DEFAULT_CFG_TRUNCATION,
    DEFAULT_GUIDANCE_SCALE,
    DEFAULT_HEIGHT,
    DEFAULT_INFERENCE_STEPS,
    DEFAULT_MAX_SEQUENCE_LENGTH,
    DEFAULT_WIDTH,
    MAX_IMAGE_SEQ_LEN,
    MAX_SHIFT,
)


def calculate_shift(
    image_seq_len,
    base_seq_len: int = BASE_IMAGE_SEQ_LEN,
    max_seq_len: int = MAX_IMAGE_SEQ_LEN,
    base_shift: float = BASE_SHIFT,
    max_shift: float = MAX_SHIFT,
):
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    mu = image_seq_len * m + b
    return mu
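

# Sketch (illustrative): calculate_shift is a linear interpolation of mu in the
# number of image tokens. The defaults below are the fallback values that
# `generate` has historically passed (256, 4096, 0.5, 1.15); the constants in
# `config` may differ. Assuming an 8x VAE and the 2x2 patching used in
# `generate`, a 1024x1024 image has (1024 // 16) ** 2 = 4096 tokens and a
# 512x512 image has 1024. `_sketch_mu` is a hypothetical helper.
def _sketch_mu(image_seq_len, base_seq_len=256, max_seq_len=4096, base_shift=0.5, max_shift=1.15):
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    # Same line as m * image_seq_len + b with b = base_shift - m * base_seq_len.
    return base_shift + slope * (image_seq_len - base_seq_len)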


def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed.")
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError("The scheduler does not support custom timestep schedules.")
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accept_sigmas:
            raise ValueError("The scheduler does not support custom sigma schedules.")
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps
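

# Sketch (illustrative): retrieve_timesteps only forwards `timesteps` or
# `sigmas` when the scheduler's set_timesteps signature declares them, using
# inspect.signature on the bound method (so `self` is not listed). The two
# scheduler classes below are hypothetical stand-ins.
def _sketch_accepts_kwarg(fn, name):
    import inspect

    return name in inspect.signature(fn).parameters


class _SketchSchedulerWithSigmas:
    def set_timesteps(self, num_inference_steps=None, device=None, sigmas=None, mu=None):
        pass


class _SketchSchedulerStepsOnly:
    def set_timesteps(self, num_inference_steps, device=None):
        pass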


@torch.no_grad()
def generate(
    transformer,
    vae,
    text_encoder,
    tokenizer,
    scheduler,
    prompt: Union[str, List[str]],
    height: int = DEFAULT_HEIGHT,
    width: int = DEFAULT_WIDTH,
    num_inference_steps: int = DEFAULT_INFERENCE_STEPS,
    guidance_scale: float = DEFAULT_GUIDANCE_SCALE,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: int = 1,
    generator: Optional[torch.Generator] = None,
    cfg_normalization: bool = False,
    cfg_truncation: float = DEFAULT_CFG_TRUNCATION,
    max_sequence_length: int = DEFAULT_MAX_SEQUENCE_LENGTH,
    output_type: str = "pil",
):
    device = next(transformer.parameters()).device

    if hasattr(vae, "config") and hasattr(vae.config, "block_out_channels"):
        vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
    else:
        vae_scale_factor = 8
    vae_scale = vae_scale_factor * 2

    if height % vae_scale != 0:
        raise ValueError(f"Height must be divisible by {vae_scale} (got {height}).")
    if width % vae_scale != 0:
        raise ValueError(f"Width must be divisible by {vae_scale} (got {width}).")

    if isinstance(prompt, str):
        batch_size = 1
        prompt = [prompt]
    else:
        batch_size = len(prompt)

    do_classifier_free_guidance = guidance_scale > 1.0
    logger.info(f"Generating image: {height}x{width}, steps={num_inference_steps}, cfg={guidance_scale}")

    formatted_prompts = []
    for p in prompt:
        messages = [{"role": "user", "content": p}]
        formatted_prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=True,
        )
        formatted_prompts.append(formatted_prompt)

    text_inputs = tokenizer(
        formatted_prompts,
        padding="max_length",
        max_length=max_sequence_length,
        truncation=True,
        return_tensors="pt",
    )

    text_input_ids = text_inputs.input_ids.to(device)
    prompt_masks = text_inputs.attention_mask.to(device).bool()

    prompt_embeds = text_encoder(
        input_ids=text_input_ids,
        attention_mask=prompt_masks,
        output_hidden_states=True,
    ).hidden_states[-2]

    prompt_embeds_list = []
    for i in range(len(prompt_embeds)):
        prompt_embeds_list.append(prompt_embeds[i][prompt_masks[i]])

    negative_prompt_embeds_list = []
    if do_classifier_free_guidance:
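        if negative_prompt is None:
            negative_prompt = ["" for _ in prompt]
        elif isinstance(negative_prompt, str):
            # Broadcast a single negative prompt to every positive prompt.
            negative_prompt = [negative_prompt] * batch_size
        if len(negative_prompt) != batch_size:
            raise ValueError(
                f"`negative_prompt` has {len(negative_prompt)} entries but `prompt` has {batch_size}."
            )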

        neg_formatted = []
        for p in negative_prompt:
            messages = [{"role": "user", "content": p}]
            formatted_prompt = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
                enable_thinking=True,
            )
            neg_formatted.append(formatted_prompt)

        neg_inputs = tokenizer(
            neg_formatted,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            return_tensors="pt",
        )

        neg_input_ids = neg_inputs.input_ids.to(device)
        neg_masks = neg_inputs.attention_mask.to(device).bool()

        neg_embeds = text_encoder(
            input_ids=neg_input_ids,
            attention_mask=neg_masks,
            output_hidden_states=True,
        ).hidden_states[-2]

        for i in range(len(neg_embeds)):
            negative_prompt_embeds_list.append(neg_embeds[i][neg_masks[i]])

    if num_images_per_prompt > 1:
        prompt_embeds_list = [pe for pe in prompt_embeds_list for _ in range(num_images_per_prompt)]
        if do_classifier_free_guidance:
            negative_prompt_embeds_list = [
                npe for npe in negative_prompt_embeds_list for _ in range(num_images_per_prompt)
            ]

    height_latent = 2 * (int(height) // vae_scale)
    width_latent = 2 * (int(width) // vae_scale)
    shape = (batch_size * num_images_per_prompt, transformer.in_channels, height_latent, width_latent)

    latents = torch.randn(shape, generator=generator, device=device, dtype=torch.float32)

    actual_batch_size = batch_size * num_images_per_prompt
    image_seq_len = (latents.shape[2] // 2) * (latents.shape[3] // 2)

    mu = calculate_shift(
        image_seq_len,
        scheduler.config.get("base_image_seq_len", BASE_IMAGE_SEQ_LEN),
        scheduler.config.get("max_image_seq_len", MAX_IMAGE_SEQ_LEN),
        scheduler.config.get("base_shift", BASE_SHIFT),
        scheduler.config.get("max_shift", MAX_SHIFT),
    )
    scheduler.sigma_min = 0.0
    scheduler_kwargs = {"mu": mu}
    timesteps, num_inference_steps = retrieve_timesteps(
        scheduler,
        num_inference_steps,
        device,
        sigmas=None,
        **scheduler_kwargs,
    )

    logger.info(f"Sampling loop start: {num_inference_steps} steps")

    from tqdm import tqdm

    # Denoising loop with progress bar
    for i, t in enumerate(tqdm(timesteps, desc="Denoising", total=len(timesteps))):
        # If current t is 0 and it's the last step, skip computation
        if t == 0 and i == len(timesteps) - 1:
            logger.debug(f"Step {i+1}/{num_inference_steps} | t: {t.item():.2f} | Skipping last step")
            continue

        timestep = t.expand(latents.shape[0])
        timestep = (1000 - timestep) / 1000
        t_norm = timestep[0].item()

        current_guidance_scale = guidance_scale
        if do_classifier_free_guidance and cfg_truncation is not None and float(cfg_truncation) <= 1:
            if t_norm > cfg_truncation:
                current_guidance_scale = 0.0

        apply_cfg = do_classifier_free_guidance and current_guidance_scale > 0

        model_dtype = next(transformer.parameters()).dtype
        if apply_cfg:
            # Positive and negative branches are batched together as [pos..., neg...].
            latent_model_input = latents.to(model_dtype).repeat(2, 1, 1, 1)
            prompt_embeds_model_input = prompt_embeds_list + negative_prompt_embeds_list
            timestep_model_input = timestep.repeat(2)
        else:
            latent_model_input = latents.to(model_dtype)
            prompt_embeds_model_input = prompt_embeds_list
            timestep_model_input = timestep

        latent_model_input = latent_model_input.unsqueeze(2)
        latent_model_input_list = list(latent_model_input.unbind(dim=0))

        model_out_list = transformer(
            latent_model_input_list,
            timestep_model_input,
            prompt_embeds_model_input,
        )[0]

        if apply_cfg:
            pos_out = model_out_list[:actual_batch_size]
            neg_out = model_out_list[actual_batch_size:]
            noise_pred = []
            for j in range(actual_batch_size):
                pos = pos_out[j].float()
                neg = neg_out[j].float()
                pred = pos + current_guidance_scale * (pos - neg)

                if cfg_normalization and float(cfg_normalization) > 0.0:
                    ori_pos_norm = torch.linalg.vector_norm(pos)
                    new_pos_norm = torch.linalg.vector_norm(pred)
                    max_new_norm = ori_pos_norm * float(cfg_normalization)
                    if new_pos_norm > max_new_norm:
                        pred = pred * (max_new_norm / new_pos_norm)
                noise_pred.append(pred)
            noise_pred = torch.stack(noise_pred, dim=0)
        else:
            noise_pred = torch.stack([out.float() for out in model_out_list], dim=0)

        noise_pred = -noise_pred.squeeze(2)
        latents = scheduler.step(noise_pred.to(torch.float32), t, latents, return_dict=False)[0]
        assert latents.dtype == torch.float32

    if output_type == "latent":
        return latents

    shift_factor = getattr(vae.config, "shift_factor", 0.0) or 0.0
    latents = (latents.to(vae.dtype) / vae.config.scaling_factor) + shift_factor
    image = vae.decode(latents, return_dict=False)[0]

    if output_type == "pil":
        from PIL import Image

        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        image = (image * 255).round().astype("uint8")
        image = [Image.fromarray(img) for img in image]

    return image
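

# Sketch (illustrative): the per-sample guidance rule in `generate`, on flat
# Python lists instead of tensors. `t_norm` is (1000 - t) / 1000, so it grows
# from about 0 toward 1 as denoising proceeds; once it exceeds
# `cfg_truncation`, guidance is switched off and only the positive branch is
# used (the real loop also skips the negative forward pass then).
# `cfg_normalization` caps the guided prediction's L2 norm at that multiple of
# the positive prediction's norm. `_sketch_guided_prediction` is hypothetical.
def _sketch_guided_prediction(pos, neg, guidance_scale, t_norm, cfg_truncation=1.0, cfg_normalization=0.0):
    import math

    scale = 0.0 if t_norm > cfg_truncation else guidance_scale
    pred = [p + scale * (p - n) for p, n in zip(pos, neg)]
    if scale > 0.0 and cfg_normalization > 0.0:
        pos_norm = math.sqrt(sum(p * p for p in pos))
        pred_norm = math.sqrt(sum(x * x for x in pred))
        max_norm = pos_norm * cfg_normalization
        if pred_norm > max_norm:
            # Rescale the guided prediction back onto the allowed norm.
            pred = [x * (max_norm / pred_norm) for x in pred]
    return pred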


================================================
FILE: src/zimage/scheduler.py
================================================
"""FlowMatchEulerDiscreteScheduler implementation."""

# Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
from dataclasses import dataclass
import math
from typing import List, Optional, Tuple, Union

import numpy as np
import torch


@dataclass
class SchedulerOutput:
    prev_sample: torch.FloatTensor


class SchedulerConfig:
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    def get(self, key, default=None):
        return self.__dict__.get(key, default)

    def __getattr__(self, name):
        return self.__dict__.get(name)


class FlowMatchEulerDiscreteScheduler:
    """Euler scheduler for flow matching."""

    def __init__(
        self,
        num_train_timesteps: int = 1000,
        shift: float = 1.0,
        use_dynamic_shifting: bool = False,
        **kwargs,
    ):
        self.num_train_timesteps = num_train_timesteps
        self.shift = shift
        self.use_dynamic_shifting = use_dynamic_shifting
        self.config = SchedulerConfig(
            num_train_timesteps=num_train_timesteps,
            shift=shift,
            use_dynamic_shifting=use_dynamic_shifting,
        )

        timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
        timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
        sigmas = timesteps / num_train_timesteps

        if not use_dynamic_shifting:
            sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)

        self.timesteps = sigmas * num_train_timesteps
        self.sigmas = sigmas.to("cpu")
        self.sigma_min = self.sigmas[-1].item()
        self.sigma_max = self.sigmas[0].item()

        self._step_index = None
        self._begin_index = None

    def set_timesteps(
        self,
        num_inference_steps: Optional[int] = None,
        device: Union[str, torch.device] = None,
        sigmas: Optional[List[float]] = None,
        mu: Optional[float] = None,
        timesteps: Optional[List[float]] = None,
    ):
        passed_timesteps = timesteps
        if num_inference_steps is None:
            num_inference_steps = len(sigmas) if sigmas is not None else len(timesteps)

        self.num_inference_steps = num_inference_steps

        if sigmas is None:
            if timesteps is None:
                timesteps = np.linspace(
                    self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps + 1
                )[:-1]
            else:
                # Accept plain Python lists as well as NumPy arrays.
                timesteps = np.asarray(timesteps, dtype=np.float32)
            sigmas = timesteps / self.num_train_timesteps
        else:
            sigmas = np.array(sigmas).astype(np.float32)

        if self.use_dynamic_shifting:
            if mu is None:
                raise ValueError("`mu` must be passed when `use_dynamic_shifting` is True.")
            sigmas = self.time_shift(mu, 1.0, sigmas)
        else:
            sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas)

        sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)

        if passed_timesteps is None:
            timesteps = sigmas * self.num_train_timesteps
        else:
            timesteps = torch.from_numpy(np.asarray(passed_timesteps, dtype=np.float32)).to(device=device)

        sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])

        self.timesteps = timesteps
        self.sigmas = sigmas
        self._step_index = None
        self._begin_index = None

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        indices = (schedule_timesteps == timestep).nonzero()
        pos = 1 if len(indices) > 1 else 0
        return indices[pos].item()

    def _init_step_index(self, timestep):
        if self._begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        return_dict: bool = True,
        **kwargs,
    ) -> Union[SchedulerOutput, Tuple]:
        """Predict the sample at the previous timestep."""
        if self._step_index is None:
            self._init_step_index(timestep)

        sample = sample.to(torch.float32)
        sigma_idx = self._step_index
        sigma = self.sigmas[sigma_idx]
        sigma_next = self.sigmas[sigma_idx + 1]

        dt = sigma_next - sigma
        prev_sample = sample + dt * model_output
        self._step_index += 1
        prev_sample = prev_sample.to(model_output.dtype)

        if not return_dict:
            return (prev_sample,)
        return SchedulerOutput(prev_sample=prev_sample)

    def _sigma_to_t(self, sigma):
        return sigma * self.num_train_timesteps

    def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
        return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
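

# Sketch (illustrative): the sigma schedule `generate` ends up with when the
# scheduler uses dynamic shifting. `generate` sets sigma_min to 0.0, so the
# base grid is N evenly spaced values from 1 down to (but excluding) 0;
# time_shift with exponent 1 maps each s to e^mu / (e^mu + 1/s - 1), and a
# trailing 0 is appended. mu = 0 leaves the grid unchanged, while mu > 0 moves
# every intermediate sigma toward 1 (more steps spent at high noise). `step`
# then performs one explicit Euler update x <- x + (sigma_next - sigma) * v.
# Both helpers are hypothetical.
def _sketch_shifted_sigmas(num_steps, mu):
    import math

    base = [1.0 - i / num_steps for i in range(num_steps)]
    shifted = [math.exp(mu) / (math.exp(mu) + (1.0 / s - 1.0)) for s in base]
    return shifted + [0.0]


def _sketch_euler_integrate(x, velocity, sigmas):
    # `velocity` is a callable (x, sigma) -> dx/dsigma, standing in for the model.
    for sigma, sigma_next in zip(sigmas[:-1], sigmas[1:]):
        x = x + (sigma_next - sigma) * velocity(x, sigma)
    return x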


================================================
FILE: src/zimage/transformer.py
================================================
"""Z-Image Transformer."""

import math
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

from config import (
    ADALN_EMBED_DIM,
    FREQUENCY_EMBEDDING_SIZE,
    MAX_PERIOD,
    ROPE_AXES_DIMS,
    ROPE_AXES_LENS,
    ROPE_THETA,
    SEQ_MULTI_OF,
)


class TimestepEmbedder(nn.Module):
    def __init__(self, out_size, mid_size=None, frequency_embedding_size=FREQUENCY_EMBEDDING_SIZE):
        super().__init__()
        if mid_size is None:
            mid_size = out_size
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, mid_size, bias=True),
            nn.SiLU(),
            nn.Linear(mid_size, out_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=MAX_PERIOD):
        with torch.amp.autocast("cuda", enabled=False):
            half = dim // 2
            freqs = torch.exp(
                -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
            )
            args = t[:, None].float() * freqs[None]
            embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
            if dim % 2:
                embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
            return embedding

    def forward(self, t):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        weight_dtype = self.mlp[0].weight.dtype
        if weight_dtype.is_floating_point:
            t_freq = t_freq.to(weight_dtype)
        t_emb = self.mlp(t_freq)
        return t_emb
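

# Sketch (illustrative): TimestepEmbedder.timestep_embedding for one scalar t,
# in plain Python. Frequencies decay geometrically from 1 toward
# 1 / max_period; the embedding is all cosines followed by all sines, with a
# zero appended when `dim` is odd. max_period=10000 is an assumed value for
# illustration; the model uses MAX_PERIOD from `config`.
def _sketch_timestep_embedding(t, dim, max_period=10000.0):
    import math

    half = dim // 2
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    args = [t * f for f in freqs]
    emb = [math.cos(a) for a in args] + [math.sin(a) for a in args]
    if dim % 2:
        emb.append(0.0)
    return emb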


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return output * self.weight
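

# Sketch (illustrative): RMSNorm on one feature vector. Unlike LayerNorm it
# does not subtract the mean; it divides by the root-mean-square and applies a
# learned per-channel gain (all ones at initialization, as in the module
# above). `_sketch_rms_norm` is a hypothetical helper.
def _sketch_rms_norm(x, weight=None, eps=1e-5):
    import math

    rms_inv = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    weight = weight if weight is not None else [1.0] * len(x)
    return [v * rms_inv * w for v, w in zip(x, weight)]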


class FeedForward(nn.Module):
    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
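

# Sketch (illustrative): FeedForward is a SwiGLU MLP, w2(silu(w1 x) * w3 x),
# written here with plain nested lists as (out_features x in_features)
# matrices and no biases, matching bias=False above. The transformer block
# sizes it with hidden_dim = int(dim / 3 * 8); for example dim = 3840 (an
# arbitrary example width) gives 10240. `_sketch_swiglu` is hypothetical.
def _sketch_swiglu(x, w1, w2, w3):
    import math

    def matvec(w, v):
        return [sum(a * b for a, b in zip(row, v)) for row in w]

    def silu(z):
        # Numerically stable z * sigmoid(z).
        if z >= 0:
            return z / (1.0 + math.exp(-z))
        e = math.exp(z)
        return z * e / (1.0 + e)

    gate = [silu(a) for a in matvec(w1, x)]
    up = matvec(w3, x)
    return matvec(w2, [g * u for g, u in zip(gate, up)])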


def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    with torch.amp.autocast("cuda", enabled=False):
        x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
        freqs_cis = freqs_cis.unsqueeze(2)
        x_out = torch.view_as_real(x * freqs_cis).flatten(3)
        return x_out.type_as(x_in)
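

# Sketch (illustrative): what apply_rotary_emb does to one head vector.
# Adjacent pairs (x[2i], x[2i+1]) are read as complex numbers and multiplied
# by a unit complex number exp(j * pos * theta_i), i.e. each pair is rotated by
# pos * theta_i. Rotation preserves the norm, and the dot product of a rotated
# query and key depends only on their position difference. The frequencies
# theta_i = base ** (-2i / d) with base = 10000 are an assumption for
# illustration; the model builds freqs_cis from ROPE_THETA and the per-axis
# ROPE_AXES_DIMS in `config`. `_sketch_rope` is a hypothetical helper.
def _sketch_rope(x, pos, base=10000.0):
    import cmath

    d = len(x)
    out = []
    for i in range(d // 2):
        theta = base ** (-2.0 * i / d)
        z = complex(x[2 * i], x[2 * i + 1]) * cmath.exp(1j * pos * theta)
        out.extend([z.real, z.imag])
    return out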


class ZImageAttention(nn.Module):
    _attention_backend = None

    def __init__(self, dim: int, n_heads: int, n_kv_heads: int, qk_norm: bool = True, eps: float = 1e-5):
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.head_dim = dim // n_heads

        self.to_q = nn.Linear(dim, n_heads * self.head_dim, bias=False)
        self.to_k = nn.Linear(dim, n_kv_heads * self.head_dim, bias=False)
        self.to_v = nn.Linear(dim, n_kv_heads * self.head_dim, bias=False)
        self.to_out = nn.ModuleList([nn.Linear(n_heads * self.head_dim, dim, bias=False)])

        self.norm_q = RMSNorm(self.head_dim, eps=eps) if qk_norm else None
        self.norm_k = RMSNorm(self.head_dim, eps=eps) if qk_norm else None

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        freqs_cis: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        query = self.to_q(hidden_states)
        key = self.to_k(hidden_states)
        value = self.to_v(hidden_states)

        query = query.unflatten(-1, (self.n_heads, -1))
        key = key.unflatten(-1, (self.n_kv_heads, -1))
        value = value.unflatten(-1, (self.n_kv_heads, -1))

        if self.norm_q is not None:
            query = self.norm_q(query)
        if self.norm_k is not None:
            key = self.norm_k(key)

        if freqs_cis is not None:
            query = apply_rotary_emb(query, freqs_cis)
            key = apply_rotary_emb(key, freqs_cis)

        dtype = query.dtype
        query, key = query.to(dtype), key.to(dtype)

        # Dispatch to the selected attention backend
        from utils.attention import dispatch_attention

        hidden_states = dispatch_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, backend=self._attention_backend
        )

        hidden_states = hidden_states.flatten(2, 3)
        hidden_states = hidden_states.to(dtype)

        output = self.to_out[0](hidden_states)
        return output


class ZImageTransformerBlock(nn.Module):
    def __init__(
        self,
        layer_id: int,
        dim: int,
        n_heads: int,
        n_kv_heads: int,
        norm_eps: float,
        qk_norm: bool,
        modulation=True,
    ):
        super().__init__()
        self.dim = dim
        self.head_dim = dim // n_heads
        self.layer_id = layer_id
        self.modulation = modulation

        self.attention = ZImageAttention(dim, n_heads, n_kv_heads, qk_norm, norm_eps)
        self.feed_forward = FeedForward(dim=dim, hidden_dim=int(dim / 3 * 8))

        self.attention_norm1 = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)
        self.attention_norm2 = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)

        if modulation:
            self.adaLN_modulation = nn.ModuleList([nn.Linear(min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True)])

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: torch.Tensor,
        freqs_cis: torch.Tensor,
        adaln_input: Optional[torch.Tensor] = None,
    ):
        if self.modulation:
            assert adaln_input is not None
            scale_msa, gate_msa, scale_mlp, gate_mlp = (
                self.adaLN_modulation[0](adaln_input).unsqueeze(1).chunk(4, dim=2)
            )
            gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()
            scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp

            attn_out = self.attention(
                self.attention_norm1(x) * scale_msa,
                attention_mask=attn_mask,
                freqs_cis=freqs_cis,
            )
            x = x + gate_msa * self.attention_norm2(attn_out)
            x = x + gate_mlp * self.ffn_norm2(self.feed_forward(self.ffn_norm1(x) * scale_mlp))
        else:
            attn_out = self.attention(
                self.attention_norm1(x),
                attention_mask=attn_mask,
                freqs_cis=freqs_cis,
            )
            x = x + self.attention_norm2(attn_out)
            x = x + self.ffn_norm2(self.feed_forward(self.ffn_norm1(x)))

        return x


class FinalLayer(nn.Module):
    def __init__(self, hidden_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(min(hidden_size, ADALN_EMBED_DIM), hidden_size, bias=True),
        )

    def forward(self, x, c):
        scale = 1.0 + self.adaLN_modulation(c)
        x = self.norm_final(x) * scale.unsqueeze(1)
        x = self.linear(x)
        return x


class RopeEmbedder:
    def __init__(
        self,
        theta: float = ROPE_THETA,
        axes_dims: List[int] = ROPE_AXES_DIMS,
        axes_lens: List[int] = ROPE_AXES_LENS,
    ):
        self.theta = theta
        self.axes_dims = axes_dims
        self.axes_lens = axes_lens
        assert len(axes_dims) == len(axes_lens)
        self.freqs_cis = None

    @staticmethod
    def precompute_freqs_cis(dim: List[int], end: List[int], theta: float = ROPE_THETA):
        with torch.device("cpu"):
            freqs_cis = []
            for i, (d, e) in enumerate(zip(dim, end)):
                freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d))
                timestep = torch.arange(e, device=freqs.device, dtype=torch.float64)
                freqs = torch.outer(timestep, freqs).float()
                freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64)
                freqs_cis.append(freqs_cis_i)
            return freqs_cis

    def __call__(self, ids: torch.Tensor):
        assert ids.ndim == 2
        assert ids.shape[-1] == len(self.axes_dims)
        device = ids.device

        if self.freqs_cis is None:
            self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
            self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]
        else:
            if self.freqs_cis[0].device != device:
                self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]

        result = []
        for i in range(len(self.axes_dims)):
            index = ids[:, i]
            result.append(self.freqs_cis[i][index])
        return torch.cat(result, dim=-1)


class ZImageTransformer2DModel(nn.Module):
    def __init__(
        self,
        all_patch_size=(2,),
        all_f_patch_size=(1,),
        in_channels=16,
        dim=3840,
        n_layers=30,
        n_refiner_layers=2,
        n_heads=30,
        n_kv_heads=30,
        norm_eps=1e-5,
        qk_norm=True,
        cap_feat_dim=2560,
        rope_theta=ROPE_THETA,
        t_scale=1000.0,
        axes_dims=ROPE_AXES_DIMS,
        axes_lens=ROPE_AXES_LENS,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.all_patch_size = all_patch_size
        self.all_f_patch_size = all_f_patch_size
        self.dim = dim
        self.n_heads = n_heads
        self.rope_theta = rope_theta
        self.t_scale = t_scale

        assert len(all_patch_size) == len(all_f_patch_size)

        all_x_embedder = {}
        all_final_layer = {}
        for patch_size, f_patch_size in zip(all_patch_size, all_f_patch_size):
            x_embedder = nn.Linear(f_patch_size * patch_size * patch_size * in_channels, dim, bias=True)
            all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder
            final_layer = FinalLayer(dim, patch_size * patch_size * f_patch_size * self.out_channels)
            all_final_layer[f"{patch_size}-{f_patch_size}"] = final_layer

        self.all_x_embedder = nn.ModuleDict(all_x_embedder)
        self.all_final_layer = nn.ModuleDict(all_final_layer)

        self.noise_refiner = nn.ModuleList(
            [
                ZImageTransformerBlock(1000 + layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation=True)
                for layer_id in range(n_refiner_layers)
            ]
        )

        self.context_refiner = nn.ModuleList(
            [
                ZImageTransformerBlock(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation=False)
                for layer_id in range(n_refiner_layers)
            ]
        )

        self.t_embedder = TimestepEmbedder(min(dim, ADALN_EMBED_DIM), mid_size=1024)
        self.cap_embedder = nn.Sequential(
            RMSNorm(cap_feat_dim, eps=norm_eps),
            nn.Linear(cap_feat_dim, dim, bias=True),
        )

        self.x_pad_token = nn.Parameter(torch.empty((1, dim)))
        self.cap_pad_token = nn.Parameter(torch.empty((1, dim)))

        self.layers = nn.ModuleList(
            [
                ZImageTransformerBlock(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm)
                for layer_id in range(n_layers)
            ]
        )

        head_dim = dim // n_heads
        assert head_dim == sum(axes_dims)
        self.axes_dims = axes_dims
        self.axes_lens = axes_lens

        self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens)

    def unpatchify(self, x: List[torch.Tensor], size: List[Tuple], patch_size, f_patch_size) -> List[torch.Tensor]:
        pH = pW = patch_size
        pF = f_patch_size
        bsz = len(x)
        assert len(size) == bsz
        for i in range(bsz):
            F, H, W = size[i]
            ori_len = (F // pF) * (H // pH) * (W // pW)
            x[i] = (
                x[i][:ori_len]
                .view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels)
                .permute(6, 0, 3, 1, 4, 2, 5)
                .reshape(self.out_channels, F, H, W)
            )
        return x

    @staticmethod
    def create_coordinate_grid(size, start=None, device=None):
        if start is None:
            start = (0 for _ in size)
        axes = [torch.arange(x0, x0 + span, dtype=torch.int32, device=device) for x0, span in zip(start, size)]
        grids = torch.meshgrid(axes, indexing="ij")
        return torch.stack(grids, dim=-1)

    def patchify_and_embed(
        self,
        all_image: List[torch.Tensor],
        all_cap_feats: List[torch.Tensor],
        patch_size: int,
        f_patch_size: int,
    ):
        pH = pW = patch_size
        pF = f_patch_size
        device = all_image[0].device

        all_image_out = []
        all_image_size = []
        all_image_pos_ids = []
        all_image_pad_mask = []
        all_cap_pos_ids = []
        all_cap_pad_mask = []
        all_cap_feats_out = []

        for _, (image, cap_feat) in enumerate(zip(all_image, all_cap_feats)):
            cap_ori_len = len(cap_feat)
            cap_padding_len = (-cap_ori_len) % SEQ_MULTI_OF
            cap_padded_pos_ids = self.create_coordinate_grid(
                size=(cap_ori_len + cap_padding_len, 1, 1),
                start=(1, 0, 0),
                device=device,
            ).flatten(0, 2)
            all_cap_pos_ids.append(cap_padded_pos_ids)
            # pad mask
            all_cap_pad_mask.append(
                torch.cat(
                    [
                        torch.zeros((cap_ori_len,), dtype=torch.bool, device=device),
                        torch.ones((cap_padding_len,), dtype=torch.bool, device=device),
                    ],
                    dim=0,
                )
                if cap_padding_len > 0
                else torch.zeros((cap_ori_len,), dtype=torch.bool, device=device)
            )
            # padded feature
            all_cap_feats_out.append(
                torch.cat(
                    [cap_feat, cap_feat[-1:].repeat(cap_padding_len, 1)],
                    dim=0,
                )
                if cap_padding_len > 0
                else cap_feat
            )

            C, F, H, W = image.size()
            all_image_size.append((F, H, W))
            F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW

            image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
            image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C)

            image_ori_len = len(image)
            image_padding_len = (-image_ori_len) % SEQ_MULTI_OF

            image_ori_pos_ids = self.create_coordinate_grid(
                size=(F_tokens, H_tokens, W_tokens),
                start=(cap_ori_len + cap_padding_len + 1, 0, 0),
                device=device,
            ).flatten(0, 2)
            image_padded_pos_ids = torch.cat(
                [
                    image_ori_pos_ids,
                    self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device)
                    .flatten(0, 2)
                    .repeat(image_padding_len, 1),
                ],
                dim=0,
            )
            all_image_pos_ids.append(image_padded_pos_ids if image_padding_len > 0 else image_ori_pos_ids)
            # pad mask
            image_pad_mask = torch.cat(
                [
                    torch.zeros((image_ori_len,), dtype=torch.bool, device=device),
                    torch.ones((image_padding_len,), dtype=torch.bool, device=device),
                ],
                dim=0,
            )
            all_image_pad_mask.append(
                image_pad_mask
                if image_padding_len > 0
                else torch.zeros((image_ori_len,), dtype=torch.bool, device=device)
            )
            # padded feature
            image_padded_feat = torch.cat(
                [image, image[-1:].repeat(image_padding_len, 1)],
                dim=0,
            )
            all_image_out.append(image_padded_feat if image_padding_len > 0 else image)

        return (
            all_image_out,
            all_cap_feats_out,
            all_image_size,
            all_image_pos_ids,
            all_cap_pos_ids,
            all_image_pad_mask,
            all_cap_pad_mask,
        )

    def forward(
        self,
        x: List[torch.Tensor],
        t,
        cap_feats: List[torch.Tensor],
        patch_size=2,
        f_patch_size=1,
    ):
        assert patch_size in self.all_patch_size
        assert f_patch_size in self.all_f_patch_size

        bsz = len(x)
        device = x[0].device
        t = t * self.t_scale
        t = self.t_embedder(t)

        (
            x,
            cap_feats,
            x_size,
            x_pos_ids,
            cap_pos_ids,
            x_inner_pad_mask,
            cap_inner_pad_mask,
        ) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size)

        x_item_seqlens = [len(_) for _ in x]
        assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens)
        x_max_item_seqlen = max(x_item_seqlens)

        x = torch.cat(x, dim=0)
        x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x)

        adaln_input = t.type_as(x)
        x[torch.cat(x_inner_pad_mask)] = self.x_pad_token
        x = list(x.split(x_item_seqlens, dim=0))
        x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split([len(_) for _ in x_pos_ids], dim=0))

        x = pad_sequence(x, batch_first=True, padding_value=0.0)
        x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0)
        # Slice explicitly so Dynamo's symbolic shape inference can see the lengths match, avoiding compilation errors
        x_freqs_cis = x_freqs_cis[:, : x.shape[1]]

        x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device)
        for i, seq_len in enumerate(x_item_seqlens):
            x_attn_mask[i, :seq_len] = 1

        for layer in self.noise_refiner:
            x = layer(x, x_attn_mask, x_freqs_cis, adaln_input)

        cap_item_seqlens = [len(_) for _ in cap_feats]
        assert all(_ % SEQ_MULTI_OF == 0 for _ in cap_item_seqlens)
        cap_max_item_seqlen = max(cap_item_seqlens)

        cap_feats = torch.cat(cap_feats, dim=0)
        cap_feats = self.cap_embedder(cap_feats)
        cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token
        cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0))
        cap_freqs_cis = list(
            self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split([len(_) for _ in cap_pos_ids], dim=0)
        )

        cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0)
        cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0)
        cap_freqs_cis = cap_freqs_cis[:, : cap_feats.shape[1]]  # same for dynamo compatibility

        cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device)
        for i, seq_len in enumerate(cap_item_seqlens):
            cap_attn_mask[i, :seq_len] = 1

        for layer in self.context_refiner:
            cap_feats = layer(cap_feats, cap_attn_mask, cap_freqs_cis)

        unified = []
        unified_freqs_cis = []
        for i in range(bsz):
            x_len = x_item_seqlens[i]
            cap_len = cap_item_seqlens[i]
            unified.append(torch.cat([x[i][:x_len], cap_feats[i][:cap_len]]))
            unified_freqs_cis.append(torch.cat([x_freqs_cis[i][:x_len], cap_freqs_cis[i][:cap_len]]))
        unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)]
        assert unified_item_seqlens == [len(_) for _ in unified]
        unified_max_item_seqlen = max(unified_item_seqlens)

        unified = pad_sequence(unified, batch_first=True, padding_value=0.0)
        unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0)
        unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device)
        for i, seq_len in enumerate(unified_item_seqlens):
            unified_attn_mask[i, :seq_len] = 1

        for layer in self.layers:
            unified = layer(unified, unified_attn_mask, unified_freqs_cis, adaln_input)

        unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, adaln_input)
        unified = list(unified.unbind(dim=0))
        x = self.unpatchify(unified, x_size, patch_size, f_patch_size)

        return x, {}
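The padding and reshaping done by `patchify_and_embed`, `unpatchify`, and the mask-building loops in `forward` can be checked in isolation. What follows is a minimal standalone sketch, not the module's own code:

- `patchify` and `unpatchify` are hypothetical free-function re-implementations of the reshapes above, for a single (C, F, H, W) latent.
- `SEQ_MULTI_OF = 32` is taken from the `src/config/model.py` preview rather than imported.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

SEQ_MULTI_OF = 32  # assumed to match src/config/model.py


def patchify(image: torch.Tensor, pF: int, pH: int, pW: int):
    """Hypothetical helper: (C, F, H, W) latent -> tokens padded to a multiple of SEQ_MULTI_OF."""
    C, F, H, W = image.shape
    Ft, Ht, Wt = F // pF, H // pH, W // pW
    tokens = (
        image.view(C, Ft, pF, Ht, pH, Wt, pW)
        .permute(1, 3, 5, 2, 4, 6, 0)
        .reshape(Ft * Ht * Wt, pF * pH * pW * C)
    )
    ori_len = tokens.shape[0]
    pad_len = (-ori_len) % SEQ_MULTI_OF
    # Pad by repeating the last token; the boolean mask flags the padded rows
    padded = torch.cat([tokens, tokens[-1:].repeat(pad_len, 1)], dim=0)
    pad_mask = torch.cat(
        [torch.zeros(ori_len, dtype=torch.bool), torch.ones(pad_len, dtype=torch.bool)]
    )
    return padded, pad_mask


def unpatchify(tokens, size, pF, pH, pW, out_channels):
    """Hypothetical helper: drop the padding and invert the permutation used by patchify."""
    F, H, W = size
    ori_len = (F // pF) * (H // pH) * (W // pW)
    return (
        tokens[:ori_len]
        .view(F // pF, H // pH, W // pW, pF, pH, pW, out_channels)
        .permute(6, 0, 3, 1, 4, 2, 5)
        .reshape(out_channels, F, H, W)
    )


# 16 latent channels, 1 frame, 20x22 latent -> 10x11 = 110 patches of 2x2
image = torch.randn(16, 1, 20, 22)
tokens, mask = patchify(image, pF=1, pH=2, pW=2)
print(tokens.shape, int(mask.sum()))  # torch.Size([128, 64]) 18
restored = unpatchify(tokens, (1, 20, 22), 1, 2, 2, out_channels=16)
assert torch.equal(restored, image)

# Batch a second, smaller image the way forward() does: right-pad to the longest
# sequence and set the attention mask True over each item's (padded) length
small, _ = patchify(torch.randn(16, 1, 8, 8), 1, 2, 2)  # 16 patches -> 32 tokens
seqs = [tokens, small]
lens = [len(s) for s in seqs]
batch = pad_sequence(seqs, batch_first=True, padding_value=0.0)
attn_mask = torch.zeros((len(seqs), max(lens)), dtype=torch.bool)
for i, n in enumerate(lens):
    attn_mask[i, :n] = True
print(batch.shape, attn_mask.sum(dim=1).tolist())  # torch.Size([2, 128, 64]) [128, 32]
```

In the model itself, the repeated padding rows are overwritten with the learned `x_pad_token` after embedding. Their values therefore do not matter; only the counts and the permutation do. Note that the inner padding stays inside the attended span, since the mask covers each item's full padded length.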

SYMBOL INDEX (125 symbols across 11 files)

FILE: batch_inference.py
  function read_prompts (line 14) | def read_prompts(path: str) -> list[str]:
  function slugify (line 30) | def slugify(text: str, max_len: int = 60) -> str:
  function select_device (line 38) | def select_device() -> str:
  function main (line 58) | def main():

FILE: inference.py
  function main (line 14) | def main():

FILE: src/tools/generate_manifest.py
  function compute_md5 (line 15) | def compute_md5(file_path: Path, chunk_size: int = 8192) -> str:
  function get_essential_files (line 24) | def get_essential_files(model_dir: Path) -> List[Path]:
  function main (line 51) | def main():

FILE: src/utils/attention.py
  class AttentionBackend (line 51) | class AttentionBackend(str, Enum):
    method print_available_backends (line 67) | def print_available_backends(cls):
  function register_backend (line 77) | def register_backend(name: str, constraints: Optional[List[Callable]] = ...
  function _check_device_cuda (line 87) | def _check_device_cuda(query: torch.Tensor, **kwargs) -> None:
  function _check_qkv_dtype_bf16_or_fp16 (line 92) | def _check_qkv_dtype_bf16_or_fp16(query: torch.Tensor, **kwargs) -> None:
  function _check_device_mps (line 97) | def _check_device_mps(query: torch.Tensor, **kwargs) -> None:
  function _process_mask (line 102) | def _process_mask(attn_mask: Optional[torch.Tensor], dtype: torch.dtype):
  function _normalize_attn_mask (line 119) | def _normalize_attn_mask(attn_mask: torch.Tensor, batch_size: int, seq_l...
  function _prepare_for_flash_attn_varlen_without_mask (line 148) | def _prepare_for_flash_attn_varlen_without_mask(
  function _prepare_for_flash_attn_varlen_with_mask (line 164) | def _prepare_for_flash_attn_varlen_with_mask(
  function _prepare_for_flash_attn_varlen (line 184) | def _prepare_for_flash_attn_varlen(
  function _flash_attention (line 197) | def _flash_attention(
  function _flash_varlen_attention (line 223) | def _flash_varlen_attention(
  function _flash_attention_3 (line 277) | def _flash_attention_3(
  function _flash_varlen_attention_3 (line 309) | def _flash_varlen_attention_3(
  function _mps_flash_attention (line 373) | def _mps_flash_attention(
  function _native_attention_wrapper (line 410) | def _native_attention_wrapper(
  function _native_flash_attention (line 440) | def _native_flash_attention(
  function _math_attention (line 462) | def _math_attention(*args, **kwargs):
  function _native_attention (line 467) | def _native_attention(*args, **kwargs):
  function dispatch_attention (line 471) | def dispatch_attention(
  function set_attention_backend (line 508) | def set_attention_backend(backend: Union[str, AttentionBackend, None]):

FILE: src/utils/helpers.py
  function format_bytes (line 14) | def format_bytes(size: float) -> str:
  function print_memory_stats (line 28) | def print_memory_stats(stage: str) -> None:
  function compute_file_md5 (line 52) | def compute_file_md5(file_path: Path, chunk_size: int = 8192) -> str:
  function load_manifest (line 61) | def load_manifest(manifest_file: Path) -> Dict[str, Optional[str]]:
  function verify_file_integrity (line 94) | def verify_file_integrity(
  function ensure_model_weights (line 135) | def ensure_model_weights(

FILE: src/utils/import_utils.py
  function is_flash_attn_available (line 6) | def is_flash_attn_available():
  function is_flash_attn_3_available (line 10) | def is_flash_attn_3_available():
  function is_torch_version (line 14) | def is_torch_version(operator: str, version: str):

FILE: src/utils/loader.py
  function load_config (line 45) | def load_config(config_path: str) -> dict:
  function load_sharded_safetensors (line 50) | def load_sharded_safetensors(weight_dir: Path, device: str = "cuda", dty...
  function load_from_local_dir (line 80) | def load_from_local_dir(

FILE: src/zimage/autoencoder.py
  class AutoencoderKLOutput (line 12) | class AutoencoderKLOutput:
  class AutoencoderConfig (line 16) | class AutoencoderConfig:
    method __init__ (line 17) | def __init__(self, **kwargs):
    method get (line 20) | def get(self, key, default=None):
    method __getattr__ (line 23) | def __getattr__(self, name):
  function swish (line 27) | def swish(x):
  class ResnetBlock2D (line 31) | class ResnetBlock2D(nn.Module):
    method __init__ (line 32) | def __init__(self, in_channels, out_channels=None, dropout=0.0, temb_c...
    method forward (line 51) | def forward(self, input_tensor, temb=None):
  class Attention (line 69) | class Attention(nn.Module):
    method __init__ (line 70) | def __init__(self, in_channels, heads=1, dim_head=None, groups=32, eps...
    method forward (line 81) | def forward(self, hidden_states):
  class Downsample2D (line 101) | class Downsample2D(nn.Module):
    method __init__ (line 102) | def __init__(self, channels, with_conv=True, out_channels=None, paddin...
    method forward (line 109) | def forward(self, hidden_states):
  class Upsample2D (line 116) | class Upsample2D(nn.Module):
    method __init__ (line 117) | def __init__(self, channels, with_conv=True, out_channels=None):
    method forward (line 124) | def forward(self, hidden_states):
  class DownEncoderBlock2D (line 131) | class DownEncoderBlock2D(nn.Module):
    method __init__ (line 132) | def __init__(self, in_channels, out_channels, num_layers=1, resnet_eps...
    method forward (line 147) | def forward(self, hidden_states):
  class UpDecoderBlock2D (line 160) | class UpDecoderBlock2D(nn.Module):
    method __init__ (line 161) | def __init__(self, in_channels, out_channels, num_layers=1, resnet_eps...
    method forward (line 174) | def forward(self, hidden_states):
  class UNetMidBlock2D (line 185) | class UNetMidBlock2D(nn.Module):
    method __init__ (line 186) | def __init__(self, in_channels, resnet_eps=1e-6, resnet_groups=32, att...
    method forward (line 196) | def forward(self, hidden_states):
  class Encoder (line 204) | class Encoder(nn.Module):
    method __init__ (line 205) | def __init__(
    method forward (line 244) | def forward(self, x):
  class Decoder (line 255) | class Decoder(nn.Module):
    method __init__ (line 256) | def __init__(
    method forward (line 293) | def forward(self, x):
  class AutoencoderKL (line 304) | class AutoencoderKL(nn.Module):
    method __init__ (line 305) | def __init__(
    method dtype (line 357) | def dtype(self):
    method decode (line 360) | def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Au...

FILE: src/zimage/pipeline.py
  function calculate_shift (line 23) | def calculate_shift(
  function retrieve_timesteps (line 36) | def retrieve_timesteps(
  function generate (line 67) | def generate(

FILE: src/zimage/scheduler.py
  class SchedulerOutput (line 13) | class SchedulerOutput:
  class SchedulerConfig (line 17) | class SchedulerConfig:
    method __init__ (line 18) | def __init__(self, **kwargs):
    method get (line 21) | def get(self, key, default=None):
    method __getattr__ (line 24) | def __getattr__(self, name):
  class FlowMatchEulerDiscreteScheduler (line 28) | class FlowMatchEulerDiscreteScheduler:
    method __init__ (line 31) | def __init__(
    method set_timesteps (line 62) | def set_timesteps(
    method index_for_timestep (line 104) | def index_for_timestep(self, timestep, schedule_timesteps=None):
    method _init_step_index (line 112) | def _init_step_index(self, timestep):
    method step (line 120) | def step(
    method _sigma_to_t (line 146) | def _sigma_to_t(self, sigma):
    method time_shift (line 149) | def time_shift(self, mu: float, sigma: float, t: torch.Tensor):

FILE: src/zimage/transformer.py
  class TimestepEmbedder (line 22) | class TimestepEmbedder(nn.Module):
    method __init__ (line 23) | def __init__(self, out_size, mid_size=None, frequency_embedding_size=F...
    method timestep_embedding (line 35) | def timestep_embedding(t, dim, max_period=MAX_PERIOD):
    method forward (line 47) | def forward(self, t):
  class RMSNorm (line 56) | class RMSNorm(nn.Module):
    method __init__ (line 57) | def __init__(self, dim: int, eps: float = 1e-5):
    method forward (line 62) | def forward(self, x: torch.Tensor) -> torch.Tensor:
  class FeedForward (line 67) | class FeedForward(nn.Module):
    method __init__ (line 68) | def __init__(self, dim: int, hidden_dim: int):
    method forward (line 74) | def forward(self, x):
  function apply_rotary_emb (line 78) | def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> tor...
  class ZImageAttention (line 86) | class ZImageAttention(nn.Module):
    method __init__ (line 89) | def __init__(self, dim: int, n_heads: int, n_kv_heads: int, qk_norm: b...
    method forward (line 103) | def forward(
  class ZImageTransformerBlock (line 143) | class ZImageTransformerBlock(nn.Module):
    method __init__ (line 144) | def __init__(
    method forward (line 171) | def forward(
  class FinalLayer (line 205) | class FinalLayer(nn.Module):
    method __init__ (line 206) | def __init__(self, hidden_size, out_channels):
    method forward (line 215) | def forward(self, x, c):
  class RopeEmbedder (line 222) | class RopeEmbedder:
    method __init__ (line 223) | def __init__(
    method precompute_freqs_cis (line 236) | def precompute_freqs_cis(dim: List[int], end: List[int], theta: float ...
    method __call__ (line 247) | def __call__(self, ids: torch.Tensor):
  class ZImageTransformer2DModel (line 266) | class ZImageTransformer2DModel(nn.Module):
    method __init__ (line 267) | def __init__(
    method unpatchify (line 345) | def unpatchify(self, x: List[torch.Tensor], size: List[Tuple], patch_s...
    method create_coordinate_grid (line 362) | def create_coordinate_grid(size, start=None, device=None):
    method patchify_and_embed (line 369) | def patchify_and_embed(
    method forward (line 474) | def forward(
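`RopeEmbedder` and `apply_rotary_emb` (indexed above) implement a multi-axis rotary embedding. It works in three steps:

1. Build one complex frequency table per axis.
2. Look each table up by the token's integer (frame or caption index, row, column) ID and concatenate the results.
3. Multiply the result into consecutive channel pairs of the query and key.

The sketch below mirrors that logic with toy sizes and is not the module's own code. The values `dims = (4, 2, 2)` and `lens = (16, 8, 8)` are hypothetical stand-ins for `ROPE_AXES_DIMS` and `ROPE_AXES_LENS`, whose real values are not shown here. The value `theta=256.0` follows the `ROPE_THETA` value in the `src/config/model.py` preview.

```python
import torch


def precompute_freqs_cis(dims, lens, theta=256.0):
    """One complex table per axis: row p, column k holds exp(i * p * theta^(-2k/d))."""
    tables = []
    for d, e in zip(dims, lens):
        freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64) / d))
        pos = torch.arange(e, dtype=torch.float64)
        angles = torch.outer(pos, freqs).float()
        tables.append(torch.polar(torch.ones_like(angles), angles))  # complex64
    return tables


def rope_lookup(tables, ids):
    """ids: (seq, n_axes) integer positions -> (seq, sum(dims) // 2) complex."""
    return torch.cat([tables[a][ids[:, a]] for a in range(len(tables))], dim=-1)


def apply_rotary(x, freqs_cis):
    """x: (batch, seq, heads, head_dim); rotate consecutive channel pairs."""
    xc = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    return torch.view_as_real(xc * freqs_cis.unsqueeze(2)).flatten(3).type_as(x)


dims, lens = (4, 2, 2), (16, 8, 8)  # hypothetical toy axes; head_dim = 4 + 2 + 2 = 8
tables = precompute_freqs_cis(dims, lens)

# (frame, row, col) ids for a 1x3x4 grid, like create_coordinate_grid produces
axes = [torch.arange(n) for n in (1, 3, 4)]
ids = torch.stack(torch.meshgrid(*axes, indexing="ij"), dim=-1).flatten(0, 2)
freqs = rope_lookup(tables, ids)
print(freqs.shape, freqs.dtype)  # torch.Size([12, 4]) torch.complex64

q = torch.randn(1, 12, 2, 8)  # (batch, seq, heads, head_dim)
q_rot = apply_rotary(q, freqs.unsqueeze(0))
# Multiplying by unit complex numbers rotates each pair, so vector norms are preserved
assert torch.allclose(q_rot.norm(dim=-1), q.norm(dim=-1), atol=1e-5)
```

The per-axis split is why `ZImageTransformer2DModel.__init__` asserts `head_dim == sum(axes_dims)`. Each axis owns a contiguous slice of every head's channels.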
Condensed preview: 24 files, each showing path, character count, and a content snippet.
[
  {
    "path": ".gitignore",
    "chars": 4802,
    "preview": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[codz]\n*$py.class\n\n# C extensions\n*.so\noutputs/\nprompts/\n# Dist"
  },
  {
    "path": "LICENSE",
    "chars": 11357,
    "preview": "                                 Apache License\n                           Version 2.0, January 2004\n                   "
  },
  {
    "path": "README.md",
    "chars": 21065,
    "preview": "<h1 align=\"center\">⚡️- Image<br><sub><sup>An Efficient Image Generation Foundation Model with Single-Stream Diffusion Tr"
  },
  {
    "path": "batch_inference.py",
    "chars": 3018,
    "preview": "\"\"\"Batch prompt inference for Z-Image.\"\"\"\n\nimport os\nfrom pathlib import Path\nimport time\n\nimport torch\n\nfrom inference "
  },
  {
    "path": "inference.py",
    "chars": 2651,
    "preview": "\"\"\"Z-Image PyTorch Native Inference.\"\"\"\n\nimport os\nimport time\nimport warnings\n\nimport torch\n\nwarnings.filterwarnings(\"i"
  },
  {
    "path": "pyproject.toml",
    "chars": 532,
    "preview": "[build-system]\nrequires = [\"setuptools>=61.0\"]\nbuild-backend = \"setuptools.build_meta\"\n\n[project]\nname = \"zimage-native\""
  },
  {
    "path": "src/__init__.py",
    "chars": 245,
    "preview": "\"\"\"Z-Image Native Implementation.\"\"\"\n\nfrom .utils import load_from_local_dir\nfrom .zimage import ZImageTransformer2DMode"
  },
  {
    "path": "src/config/__init__.py",
    "chars": 2523,
    "preview": "\"\"\"Z-Image Configuration.\"\"\"\n\nfrom .inference import (\n    DEFAULT_CFG_TRUNCATION,\n    DEFAULT_GUIDANCE_SCALE,\n    DEFAU"
  },
  {
    "path": "src/config/inference.py",
    "chars": 216,
    "preview": "\"\"\"Inference-specific configuration for Z-Image.\"\"\"\n\nDEFAULT_HEIGHT = 1024\nDEFAULT_WIDTH = 1024\nDEFAULT_INFERENCE_STEPS "
  },
  {
    "path": "src/config/manifests/README.md",
    "chars": 1880,
    "preview": "# Model Manifests\n\nThis directory contains manifest files for different Z-Image model variants.\n\n## Purpose\n\nManifest fi"
  },
  {
    "path": "src/config/manifests/z-image-turbo.txt",
    "chars": 1290,
    "preview": "# Z-Image Model Manifest\n# Format: <md5hash>  <filepath>\n# Generated automatically - DO NOT edit manually\n\n5e3226ed72a9a"
  },
  {
    "path": "src/config/model.py",
    "chars": 1149,
    "preview": "\"\"\"Model configuration constants for Z-Image.\"\"\"\n\nADALN_EMBED_DIM = 256\nSEQ_MULTI_OF = 32\n\nROPE_THETA = 256.0\nROPE_AXES_"
  },
  {
    "path": "src/tools/__init__.py",
    "chars": 169,
    "preview": "\"\"\"Tools for Z-Image model management.\"\"\"\n\nfrom .generate_manifest import compute_md5, get_essential_files\n\n__all__ = [\n"
  },
  {
    "path": "src/tools/generate_manifest.py",
    "chars": 4246,
    "preview": "#!/usr/bin/env python3\n\"\"\"Generate manifest file with MD5 checksums for model weights.\n\nUsage:\n    python -m tools.gener"
  },
  {
    "path": "src/utils/__init__.py",
    "chars": 424,
    "preview": "\"\"\"Utilities for Z-Image.\"\"\"\n\nfrom .attention import AttentionBackend, dispatch_attention, set_attention_backend\nfrom .h"
  },
  {
    "path": "src/utils/attention.py",
    "chars": 17439,
    "preview": "\"\"\"Attention backend utilities for Z-Image.\"\"\"\n\n# Modified from https://github.com/huggingface/diffusers/blob/main/src/d"
  },
  {
    "path": "src/utils/helpers.py",
    "chars": 9181,
    "preview": "\"\"\"Helper utilities for Z-Image.\"\"\"\n\nimport hashlib\nimport json\nfrom pathlib import Path\nfrom typing import Optional, Li"
  },
  {
    "path": "src/utils/import_utils.py",
    "chars": 907,
    "preview": "import importlib.util\n\nimport torch\n\n\ndef is_flash_attn_available():\n    return importlib.util.find_spec(\"flash_attn\") i"
  },
  {
    "path": "src/utils/loader.py",
    "chars": 8454,
    "preview": "\"\"\"Model loading utilities for Z-Image components.\"\"\"\n\nimport json\nimport os\nfrom pathlib import Path\nimport sys\nfrom ty"
  },
  {
    "path": "src/zimage/__init__.py",
    "chars": 190,
    "preview": "\"\"\"Z-Image PyTorch Native Implementation.\"\"\"\n\nfrom .pipeline import generate\nfrom .transformer import ZImageTransformer2"
  },
  {
    "path": "src/zimage/autoencoder.py",
    "chars": 13138,
    "preview": "\"\"\"AutoencoderKL implementation compatible with diffusers weights.\"\"\"\n\n# Modified from https://github.com/black-forest-l"
  },
  {
    "path": "src/zimage/pipeline.py",
    "chars": 10527,
    "preview": "\"\"\"Z-Image Pipeline.\"\"\"\n\nimport inspect\nfrom typing import List, Optional, Union\n\nfrom loguru import logger\nimport torch"
  },
  {
    "path": "src/zimage/scheduler.py",
    "chars": 5030,
    "preview": "\"\"\"FlowMatchEulerDiscreteScheduler implementation.\"\"\"\n\n# Modified from https://github.com/huggingface/diffusers/blob/mai"
  },
  {
    "path": "src/zimage/transformer.py",
    "chars": 21162,
    "preview": "\"\"\"Z-Image Transformer.\"\"\"\n\nimport math\nfrom typing import List, Optional, Tuple\n\nimport torch\nimport torch.nn as nn\nimp"
  }
]

About this extraction

This page contains the full source code of the Tongyi-MAI/Z-Image GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 24 files (138.3 KB), approximately 36.1k tokens, and a symbol index with 125 extracted functions, classes, methods, constants, and types.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.