Repository: Tongyi-MAI/Z-Image
Branch: main
Commit: 26f23eda626f
Files: 24
Total size: 138.3 KB
Directory structure:
gitextract_m78kzmq8/
├── .gitignore
├── LICENSE
├── README.md
├── batch_inference.py
├── inference.py
├── pyproject.toml
└── src/
├── __init__.py
├── config/
│ ├── __init__.py
│ ├── inference.py
│ ├── manifests/
│ │ ├── README.md
│ │ └── z-image-turbo.txt
│ └── model.py
├── tools/
│ ├── __init__.py
│ └── generate_manifest.py
├── utils/
│ ├── __init__.py
│ ├── attention.py
│ ├── helpers.py
│ ├── import_utils.py
│ └── loader.py
└── zimage/
├── __init__.py
├── autoencoder.py
├── pipeline.py
├── scheduler.py
└── transformer.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[codz]
*$py.class
# C extensions
*.so
outputs/
prompts/
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py.cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# UV
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
#uv.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
#poetry.toml
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
# https://pdm-project.org/en/latest/usage/project/#working-with-version-control
#pdm.lock
#pdm.toml
.pdm-python
.pdm-build/
# pixi
# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
#pixi.lock
# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
# in the .venv directory. It is recommended not to include this directory in version control.
.pixi
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.envrc
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
# Abstra
# Abstra is an AI-powered process automation framework.
# Ignore directories containing user credentials, local state, and settings.
# Learn more at https://abstra.io/docs
.abstra/
# Visual Studio Code
# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
# and can be added to the global gitignore or merged into this file. However, if you prefer,
# you could uncomment the following to ignore the entire vscode folder
# .vscode/
# Ruff stuff:
.ruff_cache/
# PyPI configuration file
.pypirc
# Cursor
# Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
# exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
# refer to https://docs.cursor.com/context/ignore-files
.cursorignore
.cursorindexingignore
# Marimo
marimo/_static/
marimo/_lsp/
__marimo__/
# Z-Image
ckpts/
.isort.cfg
.pre-commit-config.yaml
*.DS_Store
# Ignore generated images
/*.png
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
⚡️- Image
An Efficient Image Generation Foundation Model with Single-Stream Diffusion Transformer
[](https://tongyi-mai.github.io/Z-Image-blog/)
[](https://huggingface.co/Tongyi-MAI/Z-Image)
[](https://huggingface.co/Tongyi-MAI/Z-Image-Turbo)
[](https://huggingface.co/spaces/Tongyi-MAI/Z-Image)
[](https://huggingface.co/spaces/Tongyi-MAI/Z-Image-Turbo)
[](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image)
[](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo)
[](https://www.modelscope.cn/aigc/imageGeneration?tab=advanced&versionId=569345&modelType=Checkpoint&sdVersion=Z_IMAGE&modelUrl=modelscope%3A%2F%2FTongyi-MAI%2FZ-Image%3Frevision%3Dmaster)
[](https://www.modelscope.cn/aigc/imageGeneration?tab=advanced&versionId=469191&modelType=Checkpoint&sdVersion=Z_IMAGE_TURBO&modelUrl=modelscope%3A%2F%2FTongyi-MAI%2FZ-Image-Turbo%3Frevision%3Dmaster)
[](assets/Z-Image-Gallery.pdf)
[](https://modelscope.cn/studios/Tongyi-MAI/Z-Image-Gallery/summary)

Welcome to the official repository for the Z-Image (造相) project!
## ✨ Z-Image
Z-Image is a powerful and highly efficient image generation model family with **6B** parameters. Currently there are four variants:
- 🚀 **Z-Image-Turbo** – A distilled version of Z-Image that matches or exceeds leading competitors with only **8 NFEs** (Number of Function Evaluations). It offers **⚡️sub-second inference latency⚡️** on enterprise-grade H800 GPUs and fits comfortably within **consumer devices with 16 GB of VRAM**. It excels in photorealistic image generation, bilingual text rendering (English & Chinese), and robust instruction adherence.
- 🎨 **Z-Image** – The foundation model behind Z-Image-Turbo. Z-Image focuses on **high-quality generation**, **rich aesthetics**, **strong diversity**, and **controllability**, well-suited for creative generation, **fine-tuning**, and downstream development. It supports a wide range of artistic styles, effective negative prompting, and high diversity across identities, poses, compositions, and layouts.
- 🧱 **Z-Image-Omni-Base** – The versatile foundation model capable of both **generation and editing tasks**. By releasing this checkpoint, we aim to unlock the full potential for community-driven fine-tuning and custom development, providing the most "raw" and diverse starting point for the open-source community.
- ✍️ **Z-Image-Edit** – A variant fine-tuned on Z-Image specifically for image editing tasks. It supports creative image-to-image generation with impressive instruction-following capabilities, allowing for precise edits based on natural language prompts.
### 📣 News
* **[2026-01-27]** 🔥 **Z-Image is released!** We have released the model checkpoint on [Hugging Face](https://huggingface.co/Tongyi-MAI/Z-Image) and [ModelScope](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image). Try our [online demo](https://www.modelscope.cn/aigc/imageGeneration?tab=advanced&versionId=569345&modelType=Checkpoint&sdVersion=Z_IMAGE&modelUrl=modelscope%3A%2F%2FTongyi-MAI%2FZ-Image%3Frevision%3Dmaster)!
* **[2025-12-08]** 🏆 Z-Image-Turbo ranked 8th overall on the **Artificial Analysis Text-to-Image Leaderboard**, making it the 🥇 #1 open-source model! [Check out the full leaderboard](https://artificialanalysis.ai/image/leaderboard/text-to-image).
* **[2025-12-01]** 🎉 Our technical report for Z-Image is now available on [arXiv](https://arxiv.org/abs/2511.22699).
* **[2025-11-26]** 🔥 **Z-Image-Turbo is released!** We have released the model checkpoint on [Hugging Face](https://huggingface.co/Tongyi-MAI/Z-Image-Turbo) and [ModelScope](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo). Try our [online demo](https://huggingface.co/spaces/Tongyi-MAI/Z-Image-Turbo)!
### 📥 Model Zoo
| Model | Pre-Training | SFT | RL | Step | CFG | Task | Visual Quality | Diversity | Fine-Tunability | Hugging Face | ModelScope |
| :--- | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| **Z-Image-Omni-Base** | ✅ | ❌ | ❌ | 50 | ✅ | Gen. / Editing | Medium | High | Easy | *To be released* | *To be released* |
| **Z-Image** | ✅ | ✅ | ❌ | 50 | ✅ | Gen. | High | Medium | Easy | [](https://huggingface.co/Tongyi-MAI/Z-Image)
[](https://huggingface.co/spaces/Tongyi-MAI/Z-Image) | [](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image)
[](https://www.modelscope.cn/aigc/imageGeneration?tab=advanced&versionId=569345&modelType=Checkpoint&sdVersion=Z_IMAGE&modelUrl=modelscope%3A%2F%2FTongyi-MAI%2FZ-Image%3Frevision%3Dmaster) |
| **Z-Image-Turbo** | ✅ | ✅ | ✅ | 8 | ❌ | Gen. | Very High | Low | N/A | [](https://huggingface.co/Tongyi-MAI/Z-Image-Turbo)
[](https://huggingface.co/spaces/Tongyi-MAI/Z-Image-Turbo) | [](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo)
[](https://www.modelscope.cn/aigc/imageGeneration?tab=advanced&versionId=469191&modelType=Checkpoint&sdVersion=Z_IMAGE_TURBO&modelUrl=modelscope%3A%2F%2FTongyi-MAI%2FZ-Image-Turbo%3Frevision%3Dmaster) |
| **Z-Image-Edit** | ✅ | ✅ | ❌ | 50 | ✅ | Editing | High | Medium | Easy | *To be released* | *To be released* |
The figure below illustrates at which training stage each model is produced.

### 🖼️ Showcase
📸 **Photorealistic Quality**: **Z-Image-Turbo** delivers strong photorealistic image generation while maintaining excellent aesthetic quality.

📖 **Accurate Bilingual Text Rendering**: **Z-Image-Turbo** excels at accurately rendering complex Chinese and English text.

💡 **Prompt Enhancing & Reasoning**: Prompt Enhancer empowers the model with reasoning capabilities, enabling it to transcend surface-level descriptions and tap into underlying world knowledge.

🧠 **Creative Image Editing**: **Z-Image-Edit** shows a strong understanding of bilingual editing instructions, enabling imaginative and flexible image transformations.

### 🏗️ Model Architecture
We adopt a **Scalable Single-Stream DiT** (S3-DiT) architecture. In this setup, text, visual semantic tokens, and image VAE tokens are concatenated at the sequence level to serve as a unified input stream, maximizing parameter efficiency compared to dual-stream approaches.

### 📈 Performance
Z-Image-Turbo's performance has been validated on multiple independent benchmarks, where it consistently demonstrates state-of-the-art results, especially as the leading open-source model.
#### Artificial Analysis Text-to-Image Leaderboard
On the highly competitive [Artificial Analysis Leaderboard](https://artificialanalysis.ai/image/leaderboard/text-to-image), Z-Image-Turbo ranked **8th overall** and secured the top position as the 🥇 #1 Open-Source Model, outperforming all other open-source alternatives.

Artificial Analysis Leaderboard

Artificial Analysis Leaderboard (Open-Source Model Only)
#### Alibaba AI Arena Text-to-Image Leaderboard
According to the Elo-based Human Preference Evaluation on [*Alibaba AI Arena*](https://aiarena.alibaba-inc.com/corpora/arena/leaderboard?arenaType=T2I), Z-Image-Turbo also achieves state-of-the-art results among open-source models and shows highly competitive performance against leading proprietary models.

Alibaba AI Arena Text-to-Image Leaderboard
### 🚀 Quick Start
#### (1) PyTorch Native Inference
Build a virtual environment you like and then install the dependencies:
```bash
pip install -e .
```
Then run the following code to generate an image:
```bash
python inference.py
```
#### (2) Diffusers Inference
To install the latest version of diffusers, use the following command:
Click here for details on why you need to install diffusers from source
We have submitted two pull requests ([#12703](https://github.com/huggingface/diffusers/pull/12703) and [#12715](https://github.com/huggingface/diffusers/pull/12715)) to the 🤗 diffusers repository to add support for Z-Image. Both PRs have been merged into diffusers.
Installing diffusers from source ensures that you have the latest features and Z-Image support.
```bash
pip install git+https://github.com/huggingface/diffusers
```
Z-Image-Turbo - Click to expand
Then, try the following code to generate an image:
```python
import torch
from diffusers import ZImagePipeline
# 1. Load the pipeline
# Use bfloat16 for optimal performance on supported GPUs
pipe = ZImagePipeline.from_pretrained(
"Tongyi-MAI/Z-Image-Turbo",
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=False,
)
pipe.to("cuda")
# [Optional] Attention Backend
# Diffusers uses SDPA by default. Switch to Flash Attention for better efficiency if supported:
# pipe.transformer.set_attention_backend("flash") # Enable Flash-Attention-2
# pipe.transformer.set_attention_backend("_flash_3") # Enable Flash-Attention-3
# [Optional] Model Compilation
# Compiling the DiT model accelerates inference, but the first run will take longer to compile.
# pipe.transformer.compile()
# [Optional] CPU Offloading
# Enable CPU offloading for memory-constrained devices.
# pipe.enable_model_cpu_offload()
prompt = "Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights."
# 2. Generate Image
image = pipe(
prompt=prompt,
height=1024,
width=1024,
num_inference_steps=9, # This actually results in 8 DiT forwards
guidance_scale=0.0, # Guidance should be 0 for the Turbo models
generator=torch.Generator("cuda").manual_seed(42),
).images[0]
image.save("example.png")
```
Z-Image - Click to expand
Recommended Parameters:
- **Resolution:** 512×512 to 2048×2048 (total pixel area, any aspect ratio)
- **Guidance scale:** 3.0 – 5.0
- **Inference steps:** 28 – 50
- **Negative prompts:** Strongly recommended for better control
- **CFG normalization:** `False` for general stylized generation, `True` for realism
Then, try the following code to generate an image:
```python
import torch
from diffusers import ZImagePipeline
# Load the pipeline
pipe = ZImagePipeline.from_pretrained(
"Tongyi-MAI/Z-Image",
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=False,
)
pipe.to("cuda")
# Generate image
prompt = "两名年轻亚裔女性紧密站在一起,背景为朴素的灰色纹理墙面,可能是室内地毯地面。左侧女性留着长卷发,身穿藏青色毛衣,左袖有奶油色褶皱装饰,内搭白色立领衬衫,下身白色裤子;佩戴小巧金色耳钉,双臂交叉于背后。右侧女性留直肩长发,身穿奶油色卫衣,胸前印有"Tun the tables"字样,下方为"New ideas",搭配白色裤子;佩戴银色小环耳环,双臂交叉于胸前。两人均面带微笑直视镜头。照片,自然光照明,柔和阴影,以藏青、奶油白为主的中性色调,休闲时尚摄影,中等景深,面部和上半身对焦清晰,姿态放松,表情友好,室内环境,地毯地面,纯色背景。"
negative_prompt = "" # Optional, but would be powerful when you want to remove some unwanted content
image = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
height=1280,
width=720,
cfg_normalization=False,
num_inference_steps=50,
guidance_scale=4,
generator=torch.Generator("cuda").manual_seed(42),
).images[0]
image.save("example.png")
```
## 🔬 Decoupled-DMD: The Acceleration Magic Behind Z-Image
[](https://arxiv.org/abs/2511.22677)
Decoupled-DMD is the core few-step distillation algorithm that empowers the 8-step Z-Image model.
Our core insight in Decoupled-DMD is that the success of existing DMD (Distribution Matching Distillation) methods is the result of two independent, collaborating mechanisms:
- **CFG Augmentation (CA)**: The primary **engine** 🚀 driving the distillation process, a factor largely overlooked in previous work.
- **Distribution Matching (DM)**: Acts more as a **regularizer** ⚖️, ensuring the stability and quality of the generated output.
By recognizing and decoupling these two mechanisms, we were able to study and optimize them in isolation. This ultimately motivated us to develop an improved distillation process that significantly enhances the performance of few-step generation.

## 🤖 DMDR: Fusing DMD with Reinforcement Learning
[](https://arxiv.org/abs/2511.13649)
Building upon the strong foundation of Decoupled-DMD, our 8-step Z-Image model has already demonstrated exceptional capabilities. To achieve further improvements in terms of semantic alignment, aesthetic quality, and structural coherence—while producing images with richer high-frequency details—we present **DMDR**.
Our core insight behind DMDR is that Reinforcement Learning (RL) and Distribution Matching Distillation (DMD) can be synergistically integrated during the post-training of few-step models. We demonstrate that:
- **RL Unlocks the Performance of DMD** 🚀
- **DMD Effectively Regularizes RL** ⚖️

## 🎉 Community Works
- [Cache-DiT](https://github.com/vipshop/cache-dit) provides inference acceleration for **Z-Image** and **Z-Image-ControlNet** via DBCache, Context Parallelism and Tensor Parallelism. It achieves nearly **4x** speedup on 4 GPUs with negligible precision loss. Please visit their [example](https://github.com/vipshop/cache-dit/blob/main/examples) for more details.
- [stable-diffusion.cpp](https://github.com/leejet/stable-diffusion.cpp) is a pure C++ diffusion model inference engine that supports fast and memory-efficient Z-Image inference across multiple platforms (CUDA, Vulkan, etc.). You can use stable-diffusion.cpp to generate images with Z-Image on machines with as little as **4GB** of VRAM. For more information, please refer to [How to Use Z‐Image on a GPU with Only 4GB VRAM](https://github.com/leejet/stable-diffusion.cpp/wiki/How-to-Use-Z%E2%80%90Image-on-a-GPU-with-Only-4GB-VRAM).
- [LeMiCa](https://github.com/UnicomAI/LeMiCa) provides a training-free, timestep-level acceleration method that conveniently speeds up Z-Image inference. For more details, see [LeMiCa4Z-Image](https://github.com/UnicomAI/LeMiCa/tree/main/LeMiCa4Z-Image).
- [ComfyUI ZImageLatent](https://github.com/HellerCommaA/ComfyUI-ZImageLatent) provides easy-to-use latents at the official Z-Image resolutions.
- [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio) has provided more support for Z-Image, including LoRA training, full training, distillation training, and low-VRAM inference. Please refer to the [document](https://github.com/modelscope/DiffSynth-Studio/blob/main/docs/en/Model_Details/Z-Image.md) of DiffSynth-Studio.
- [vllm-omni](https://github.com/vllm-project/vllm-omni), a framework for fast inference and serving of omni-modality models, now [supports](https://github.com/vllm-project/vllm-omni/blob/main/docs/models/supported_models.md) Z-Image.
- [SGLang-Diffusion](https://lmsys.org/blog/2025-11-07-sglang-diffusion/) brings SGLang's state-of-the-art performance to accelerate image and video generation for diffusion models, now [supporting](https://github.com/sgl-project/sglang/blob/main/python/sglang/multimodal_gen/runtime/pipelines/zimage_pipeline.py) Z-Image.
- [Candle](https://github.com/huggingface/candle) is a minimalist machine learning (ML) framework for Rust launched by Hugging Face, which now [supports](https://github.com/huggingface/candle/pull/3261) Z-Image.
- [MeanCache](https://github.com/UnicomAI/MeanCache), a training-free inference acceleration method for Flow Matching models by China Unicom Data Science and Artificial Intelligence Research Institute. Delivers up to **3.7x** speedup for **Z-Image** generation with plug-and-play integration while preserving output quality.
## 🚀 Star History
[](https://www.star-history.com/#Tongyi-MAI/Z-Image&type=date&legend=top-left)
## 📜 Citation
If you find our work useful in your research, please consider citing:
```bibtex
@article{team2025zimage,
title={Z-Image: An Efficient Image Generation Foundation Model with Single-Stream Diffusion Transformer},
author={Z-Image Team},
journal={arXiv preprint arXiv:2511.22699},
year={2025}
}
@article{liu2025decoupled,
title={Decoupled DMD: CFG Augmentation as the Spear, Distribution Matching as the Shield},
author={Dongyang Liu and Peng Gao and David Liu and Ruoyi Du and Zhen Li and Qilong Wu and Xin Jin and Sihan Cao and Shifeng Zhang and Hongsheng Li and Steven Hoi},
journal={arXiv preprint arXiv:2511.22677},
year={2025}
}
@article{jiang2025distribution,
title={Distribution Matching Distillation Meets Reinforcement Learning},
author={Jiang, Dengyang and Liu, Dongyang and Wang, Zanyi and Wu, Qilong and Jin, Xin and Liu, David and Li, Zhen and Wang, Mengmeng and Gao, Peng and Yang, Harry},
journal={arXiv preprint arXiv:2511.13649},
year={2025}
}
```
## 🤝 We're Hiring!
We're actively looking for **Research Scientists**, **Engineers**, and **Interns** to work on foundational generative models and their applications. Interested candidates, please send your resume to: **jingpeng.gp@alibaba-inc.com**
================================================
FILE: batch_inference.py
================================================
"""Batch prompt inference for Z-Image."""
import os
from pathlib import Path
import time

import torch

# NOTE: `ensure_model_weights` comes from `utils` (as in inference.py). The previous
# `from inference import ensure_weights` referenced a name inference.py does not define,
# so this script failed with ImportError before doing anything.
from utils import AttentionBackend, ensure_model_weights, load_from_local_dir, set_attention_backend
from zimage import generate


def read_prompts(path: str) -> list[str]:
    """Read prompts from a UTF-8 text file, one prompt per line.

    Leading/trailing whitespace is stripped from every line and blank lines are skipped.

    Args:
        path: Path to the prompt file.

    Returns:
        The non-empty, stripped prompt lines in file order.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        ValueError: If the file contains no non-blank lines.
    """
    prompt_path = Path(path)
    if not prompt_path.exists():
        raise FileNotFoundError(f"Prompt file not found: {prompt_path}")
    with prompt_path.open("r", encoding="utf-8") as f:
        prompts = [line.strip() for line in f if line.strip()]
    if not prompts:
        raise ValueError(f"No prompts found in {prompt_path}")
    return prompts


# Loaded at import time; the file is taken from $PROMPTS_FILE (default prompts/prompt1.txt).
PROMPTS = read_prompts(os.environ.get("PROMPTS_FILE", "prompts/prompt1.txt"))


def slugify(text: str, max_len: int = 60) -> str:
    """Create a filesystem-safe slug from a prompt.

    Every non-alphanumeric character becomes a separator, runs of separators collapse
    to a single ``-``, and the result is lower-cased and truncated to ``max_len``.
    Falls back to ``"prompt"`` when nothing usable remains.
    """
    slug = "".join(ch.lower() if ch.isalnum() else "-" for ch in text)
    slug = "-".join(part for part in slug.split("-") if part)
    return slug[:max_len].rstrip("-") or "prompt"


def select_device() -> "str | torch.device":
    """Pick the best available device, in priority order cuda -> tpu (XLA) -> mps -> cpu.

    Returns a device string, or the XLA device object when torch_xla is usable.
    """
    if torch.cuda.is_available():
        print("Chosen device: cuda")
        return "cuda"
    try:
        import torch_xla.core.xla_model as xm

        device = xm.xla_device()
        print("Chosen device: tpu")
        return device
    except (ImportError, RuntimeError):
        if torch.backends.mps.is_available():
            print("Chosen device: mps")
            return "mps"
        print("Chosen device: cpu")
        return "cpu"


def main():
    """Generate one image per entry of ``PROMPTS`` and save each under ``outputs/``.

    Files are named ``prompt-NN-SLUG.png``, where NN is the 1-based prompt index.
    """
    model_path = ensure_model_weights("ckpts/Z-Image-Turbo", verify=False)  # True to verify with md5
    dtype = torch.bfloat16
    use_compile = False  # named to avoid shadowing the builtin `compile`
    height = 1024
    width = 1024
    num_inference_steps = 8
    guidance_scale = 0.0
    attn_backend = os.environ.get("ZIMAGE_ATTENTION", "_native_flash")
    output_dir = Path("outputs")
    output_dir.mkdir(exist_ok=True)

    device = select_device()
    components = load_from_local_dir(model_path, device=device, dtype=dtype, compile=use_compile)
    AttentionBackend.print_available_backends()
    set_attention_backend(attn_backend)
    print(f"Chosen attention backend: {attn_backend}")

    for idx, prompt in enumerate(PROMPTS, start=1):
        output_path = output_dir / f"prompt-{idx:02d}-{slugify(prompt)}.png"
        # A distinct, reproducible seed per prompt position: 42, 43, 44, ...
        seed = 42 + idx - 1
        generator = torch.Generator(device).manual_seed(seed)
        # perf_counter is monotonic and meant for measuring intervals.
        start_time = time.perf_counter()
        images = generate(
            prompt=prompt,
            **components,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            generator=generator,
        )
        elapsed = time.perf_counter() - start_time
        images[0].save(output_path)
        print(f"[{idx}/{len(PROMPTS)}] Saved {output_path} in {elapsed:.2f} seconds")
    print("Done.")


if __name__ == "__main__":
    main()
================================================
FILE: inference.py
================================================
"""Z-Image PyTorch Native Inference."""
import os
import time
import warnings
import torch
warnings.filterwarnings("ignore")
from utils import AttentionBackend, ensure_model_weights, load_from_local_dir, set_attention_backend
from zimage import generate
def main():
    """Generate one example image with Z-Image-Turbo and save it to ``example.png``.

    The attention backend comes from $ZIMAGE_ATTENTION (default ``_native_flash``).
    The device is picked in priority order cuda -> tpu (XLA) -> mps -> cpu.
    """
    model_path = ensure_model_weights("ckpts/Z-Image-Turbo", verify=False)  # True to verify with md5
    dtype = torch.bfloat16
    use_compile = False  # default False for compatibility; named to avoid shadowing builtin `compile`
    output_path = "example.png"
    height = 1024
    width = 1024
    num_inference_steps = 8
    guidance_scale = 0.0
    seed = 42
    attn_backend = os.environ.get("ZIMAGE_ATTENTION", "_native_flash")
    prompt = (
        "Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. "
        "Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. "
        "Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. Soft-lit outdoor night background, "
        "silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights."
    )

    # Device selection priority: cuda -> tpu -> mps -> cpu
    if torch.cuda.is_available():
        device = "cuda"
        print("Chosen device: cuda")
    else:
        try:
            # Importing the submodule also imports the torch_xla package itself.
            import torch_xla.core.xla_model as xm

            device = xm.xla_device()
            print("Chosen device: tpu")
        except (ImportError, RuntimeError):
            if torch.backends.mps.is_available():
                device = "mps"
                print("Chosen device: mps")
            else:
                device = "cpu"
                print("Chosen device: cpu")

    # Load models
    components = load_from_local_dir(model_path, device=device, dtype=dtype, compile=use_compile)
    AttentionBackend.print_available_backends()
    set_attention_backend(attn_backend)
    print(f"Chosen attention backend: {attn_backend}")

    # Gen an image (perf_counter: monotonic clock meant for interval timing)
    start_time = time.perf_counter()
    images = generate(
        prompt=prompt,
        **components,
        height=height,
        width=width,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=torch.Generator(device).manual_seed(seed),
    )
    end_time = time.perf_counter()
    print(f"Time taken: {end_time - start_time:.2f} seconds")
    images[0].save(output_path)


### !! For best speed performance, recommend to use `_flash_3` backend and set `compile=True`
### This would give you sub-second generation speed on Hopper GPU (H100/H200/H800) after warm-up
if __name__ == "__main__":
    main()
================================================
FILE: pyproject.toml
================================================
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
[project]
name = "zimage-native"
version = "0.1.0"
description = "Z-Image PyTorch Native Implementation"
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
"torch>=2.5.0",
"transformers>=4.51.0",
"safetensors",
"loguru",
"pillow",
"accelerate",
"huggingface_hub>=0.25.0"
]
[project.optional-dependencies]
dev = [
"black",
"isort",
"ruff"
]
[tool.setuptools.packages.find]
where = ["src"]
================================================
FILE: src/__init__.py
================================================
"""Z-Image Native Implementation."""
from .utils import load_from_local_dir
from .zimage import ZImageTransformer2DModel, generate
# Package version; keep in sync with `version` in pyproject.toml (currently "0.1.0").
__version__ = "0.1.0"
# Public names re-exported at the package root.
__all__ = [
    "ZImageTransformer2DModel",
    "generate",
    "load_from_local_dir",
]
================================================
FILE: src/config/__init__.py
================================================
"""Z-Image Configuration."""
from .inference import (
DEFAULT_CFG_TRUNCATION,
DEFAULT_GUIDANCE_SCALE,
DEFAULT_HEIGHT,
DEFAULT_INFERENCE_STEPS,
DEFAULT_MAX_SEQUENCE_LENGTH,
DEFAULT_WIDTH,
)
from .model import (
ADALN_EMBED_DIM,
BASE_IMAGE_SEQ_LEN,
BASE_SHIFT,
BYTES_PER_GB,
DEFAULT_LOAD_DEVICE,
DEFAULT_LOAD_DTYPE_STR,
DEFAULT_SCHEDULER_NUM_TRAIN_TIMESTEPS,
DEFAULT_SCHEDULER_SHIFT,
DEFAULT_SCHEDULER_USE_DYNAMIC_SHIFTING,
DEFAULT_TRANSFORMER_CAP_FEAT_DIM,
DEFAULT_TRANSFORMER_DIM,
DEFAULT_TRANSFORMER_F_PATCH_SIZE,
DEFAULT_TRANSFORMER_IN_CHANNELS,
DEFAULT_TRANSFORMER_N_HEADS,
DEFAULT_TRANSFORMER_N_KV_HEADS,
DEFAULT_TRANSFORMER_N_LAYERS,
DEFAULT_TRANSFORMER_N_REFINER_LAYERS,
DEFAULT_TRANSFORMER_NORM_EPS,
DEFAULT_TRANSFORMER_PATCH_SIZE,
DEFAULT_TRANSFORMER_QK_NORM,
DEFAULT_TRANSFORMER_T_SCALE,
DEFAULT_VAE_IN_CHANNELS,
DEFAULT_VAE_LATENT_CHANNELS,
DEFAULT_VAE_NORM_NUM_GROUPS,
DEFAULT_VAE_OUT_CHANNELS,
DEFAULT_VAE_SCALE_FACTOR,
DEFAULT_VAE_SCALING_FACTOR,
FREQUENCY_EMBEDDING_SIZE,
MAX_IMAGE_SEQ_LEN,
MAX_PERIOD,
MAX_SHIFT,
ROPE_AXES_DIMS,
ROPE_AXES_LENS,
ROPE_THETA,
SEQ_MULTI_OF,
)
# Public configuration names, re-exported from .model and .inference.
__all__ = [
    # --- from .model ---
    "ADALN_EMBED_DIM",
    "SEQ_MULTI_OF",
    "ROPE_THETA",
    "ROPE_AXES_DIMS",
    "ROPE_AXES_LENS",
    "FREQUENCY_EMBEDDING_SIZE",
    "MAX_PERIOD",
    "BASE_IMAGE_SEQ_LEN",
    "MAX_IMAGE_SEQ_LEN",
    "BASE_SHIFT",
    "MAX_SHIFT",
    "DEFAULT_VAE_SCALE_FACTOR",
    "DEFAULT_VAE_IN_CHANNELS",
    "DEFAULT_VAE_OUT_CHANNELS",
    "DEFAULT_VAE_LATENT_CHANNELS",
    "DEFAULT_VAE_NORM_NUM_GROUPS",
    "DEFAULT_VAE_SCALING_FACTOR",
    "DEFAULT_TRANSFORMER_PATCH_SIZE",
    "DEFAULT_TRANSFORMER_F_PATCH_SIZE",
    "DEFAULT_TRANSFORMER_IN_CHANNELS",
    "DEFAULT_TRANSFORMER_DIM",
    "DEFAULT_TRANSFORMER_N_LAYERS",
    "DEFAULT_TRANSFORMER_N_REFINER_LAYERS",
    "DEFAULT_TRANSFORMER_N_HEADS",
    "DEFAULT_TRANSFORMER_N_KV_HEADS",
    "DEFAULT_TRANSFORMER_NORM_EPS",
    "DEFAULT_TRANSFORMER_QK_NORM",
    "DEFAULT_TRANSFORMER_CAP_FEAT_DIM",
    "DEFAULT_TRANSFORMER_T_SCALE",
    "DEFAULT_SCHEDULER_NUM_TRAIN_TIMESTEPS",
    "DEFAULT_SCHEDULER_SHIFT",
    "DEFAULT_SCHEDULER_USE_DYNAMIC_SHIFTING",
    "DEFAULT_LOAD_DEVICE",
    "DEFAULT_LOAD_DTYPE_STR",
    "BYTES_PER_GB",
    # --- from .inference ---
    "DEFAULT_HEIGHT",
    "DEFAULT_WIDTH",
    "DEFAULT_INFERENCE_STEPS",
    "DEFAULT_GUIDANCE_SCALE",
    "DEFAULT_CFG_TRUNCATION",
    "DEFAULT_MAX_SEQUENCE_LENGTH",
]
================================================
FILE: src/config/inference.py
================================================
"""Inference-specific configuration for Z-Image."""
# Default generation settings. inference.py and batch_inference.py set the same
# values (1024x1024, 8 steps, guidance 0.0) as locals instead of importing these.
DEFAULT_HEIGHT = 1024  # presumably output height in pixels — confirm against the pipeline
DEFAULT_WIDTH = 1024  # presumably output width in pixels — confirm against the pipeline
DEFAULT_INFERENCE_STEPS = 8  # same value inference.py passes as num_inference_steps
DEFAULT_GUIDANCE_SCALE = 0.0  # same value inference.py passes as guidance_scale
# NOTE(review): semantics not visible here; presumably the fraction of the schedule
# during which CFG is applied — confirm in the pipeline.
DEFAULT_CFG_TRUNCATION = 1.0
DEFAULT_MAX_SEQUENCE_LENGTH = 512  # presumably max prompt tokens for the text encoder — confirm
================================================
FILE: src/config/manifests/README.md
================================================
# Model Manifests
This directory contains manifest files for different Z-Image model variants.
## Purpose
Manifest files list all required files for each model, optionally with MD5 checksums for integrity verification.
## File Naming Convention
- `z-image-turbo.txt` - Z-Image Turbo model
- Custom models: `{model-name}.txt`
## Format
### Standard Format (with MD5 - Recommended)
```txt
# Z-Image Model Manifest
# Format: MD5_HEX RELATIVE_PATH
# Generated automatically - DO NOT edit manually
5e3226ed72a9a4a080f2a4ca78b98ddc model_index.json
ca682fcc6c5a94cf726b7187e64b9411 scheduler/scheduler_config.json
1e97eb35d9d0b6aa60c58a8df8d7d99a text_encoder/config.json
30b85686b9a9b002e012494fadc027cb text_encoder/model-00001-of-00003.safetensors
...
```
**Verification Behavior:**
- `verify=False`: Default, only checks file existence, ignores MD5 (fast)
- `verify=True`: Checks existence AND verifies MD5 checksums (thorough)
## Usage
The manifest file is automatically selected based on the model directory name:
```python
# Auto-detects manifest from "Z-Image-Turbo" -> uses z-image-turbo.txt
model_path = ensure_model_weights("ckpts/Z-Image-Turbo")
# Explicit manifest
model_path = ensure_model_weights("ckpts/Z-Image-Turbo", manifest_name="z-image-turbo.txt")
```
## Generating Manifests
Use the provided tool to generate manifests:
```bash
# Generate with MD5 checksums (auto-saves to this directory)
python -m src.tools.generate_manifest ckpts/Z-Image-Turbo
# Generate without checksums (faster, not recommended)
python -m src.tools.generate_manifest ckpts/Z-Image-Turbo --no-checksums
# With verbose output
python -m src.tools.generate_manifest ckpts/Z-Image-Turbo --verbose
# Custom output path
python -m src.tools.generate_manifest ckpts/Z-Image-Turbo --output custom.txt
```
## Available Manifests
- **z-image-turbo.txt** - Z-Image Turbo model
================================================
FILE: src/config/manifests/z-image-turbo.txt
================================================
# Z-Image Model Manifest
# Format:
# Generated automatically - DO NOT edit manually
5e3226ed72a9a4a080f2a4ca78b98ddc model_index.json
ca682fcc6c5a94cf726b7187e64b9411 scheduler/scheduler_config.json
1e97eb35d9d0b6aa60c58a8df8d7d99a text_encoder/config.json
30b85686b9a9b002e012494fadc027cb text_encoder/model-00001-of-00003.safetensors
e6a24ea164404a01ad2800dbae4e1a13 text_encoder/model-00002-of-00003.safetensors
09e190ed15ff14795b6277e023cfcb2d text_encoder/model-00003-of-00003.safetensors
589f5395156900f49d617aee8a8d8708 text_encoder/model.safetensors.index.json
6423133b9cc1a2077b57822c30c211aa tokenizer/tokenizer.json
b06e103ac555ec4b51266078b518c0f0 tokenizer/tokenizer_config.json
baed87136fe5f848e24b072f99856cc3 transformer/config.json
54889d0dd179b4fa2fd7bd0e487d856e transformer/diffusion_pytorch_model-00001-of-00003.safetensors
fe81e804658d345323512c63224b0604 transformer/diffusion_pytorch_model-00002-of-00003.safetensors
4e074e09129f98ad840414951f122feb transformer/diffusion_pytorch_model-00003-of-00003.safetensors
76d788eb0d42c59cc8f8ec007db639aa transformer/diffusion_pytorch_model.safetensors.index.json
ba9e2980c8630b4abccc643bc9f4a542 vae/config.json
6f83de55cb720c7fae051b14528577bf vae/diffusion_pytorch_model.safetensors
================================================
FILE: src/config/model.py
================================================
"""Model configuration constants for Z-Image."""
# --- Transformer conditioning / sequence layout ---
ADALN_EMBED_DIM = 256  # presumably the conditioning width fed to adaLN modulation — TODO confirm
SEQ_MULTI_OF = 32  # presumably sequences are padded to a multiple of this length — confirm in transformer
# --- Rotary position embedding ---
ROPE_THETA = 256.0
# Per-axis RoPE dims. 32 + 48 + 48 = 128, which equals DEFAULT_TRANSFORMER_DIM /
# DEFAULT_TRANSFORMER_N_HEADS (3840 / 30) — presumably the per-head dim split across axes.
ROPE_AXES_DIMS = [32, 48, 48]
ROPE_AXES_LENS = [1536, 512, 512]  # presumably max positions per RoPE axis — confirm
# --- Timestep (frequency) embedding ---
FREQUENCY_EMBEDDING_SIZE = 256
MAX_PERIOD = 10000
# --- Sequence-length-dependent shift; presumably only used when dynamic shifting is
# enabled (DEFAULT_SCHEDULER_USE_DYNAMIC_SHIFTING below is False) — TODO confirm ---
BASE_IMAGE_SEQ_LEN = 256
MAX_IMAGE_SEQ_LEN = 4096
BASE_SHIFT = 0.5
MAX_SHIFT = 1.15
# --- VAE defaults ---
# NOTE(review): 4 latent channels and scaling 0.18215 do not obviously agree with
# DEFAULT_TRANSFORMER_IN_CHANNELS = 16; presumably these are fallbacks overridden by
# vae/config.json at load time — confirm before relying on them.
DEFAULT_VAE_SCALE_FACTOR = 8
DEFAULT_VAE_IN_CHANNELS = 3
DEFAULT_VAE_OUT_CHANNELS = 3
DEFAULT_VAE_LATENT_CHANNELS = 4
DEFAULT_VAE_NORM_NUM_GROUPS = 32
DEFAULT_VAE_SCALING_FACTOR = 0.18215
# --- Transformer defaults (patch sizes are stored as 1-tuples) ---
DEFAULT_TRANSFORMER_PATCH_SIZE = (2,)
DEFAULT_TRANSFORMER_F_PATCH_SIZE = (1,)
DEFAULT_TRANSFORMER_IN_CHANNELS = 16
DEFAULT_TRANSFORMER_DIM = 3840
DEFAULT_TRANSFORMER_N_LAYERS = 30
DEFAULT_TRANSFORMER_N_REFINER_LAYERS = 2
DEFAULT_TRANSFORMER_N_HEADS = 30
DEFAULT_TRANSFORMER_N_KV_HEADS = 30
DEFAULT_TRANSFORMER_NORM_EPS = 1e-5
DEFAULT_TRANSFORMER_QK_NORM = True
DEFAULT_TRANSFORMER_CAP_FEAT_DIM = 2560  # presumably the caption/text-feature width — confirm
DEFAULT_TRANSFORMER_T_SCALE = 1000.0
# --- Scheduler defaults ---
DEFAULT_SCHEDULER_NUM_TRAIN_TIMESTEPS = 1000
DEFAULT_SCHEDULER_SHIFT = 3.0
DEFAULT_SCHEDULER_USE_DYNAMIC_SHIFTING = False
# --- Loading ---
DEFAULT_LOAD_DEVICE = "cuda"
DEFAULT_LOAD_DTYPE_STR = "bfloat16"
BYTES_PER_GB = 2**30  # binary gigabyte (GiB), not 10**9
================================================
FILE: src/tools/__init__.py
================================================
"""Tools for Z-Image model management."""
from .generate_manifest import compute_md5, get_essential_files
# Public helpers re-exported from .generate_manifest.
__all__ = [
    "compute_md5",
    "get_essential_files",
]
================================================
FILE: src/tools/generate_manifest.py
================================================
#!/usr/bin/env python3
"""Generate manifest file with MD5 checksums for model weights.
Usage:
python -m tools.generate_manifest ckpts/Z-Image-Turbo
python -m tools.generate_manifest ckpts/Z-Image-Turbo --no-checksums # Only list files
"""
import argparse
import hashlib
from pathlib import Path
from typing import List
def compute_md5(file_path: Path, chunk_size: int = 1 << 20) -> str:
    """Compute the hexadecimal MD5 digest of a file, streaming it in chunks.

    MD5 serves only as an integrity checksum for downloaded weights, not as a security
    primitive. The hash is therefore created with ``usedforsecurity=False``, which keeps
    it usable on FIPS-restricted Python builds that reject a plain ``hashlib.md5()``.

    Args:
        file_path: File to hash.
        chunk_size: Bytes read per iteration. The default (1 MiB) keeps the Python-level
            loop short for multi-GB weight shards. The digest does not depend on it.

    Returns:
        The 32-character lowercase hexadecimal MD5 digest.

    Raises:
        ValueError: If ``chunk_size`` is not positive. ``read(0)`` would return ``b""``
            immediately and silently yield the digest of an empty file.
        OSError: If the file cannot be opened or read.
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")
    md5_hash = hashlib.md5(usedforsecurity=False)
    with open(file_path, "rb") as f:
        while chunk := f.read(chunk_size):
            md5_hash.update(chunk)
    return md5_hash.hexdigest()
def get_essential_files(model_dir: Path) -> List[Path]:
    """Collect the files under ``model_dir`` that belong in a model manifest.

    Entries containing ``*`` are globbed and every match is kept. Other entries are
    literal relative paths, kept only if they exist. The combined list is sorted.
    """
    wanted = (
        "model_index.json",
        "transformer/config.json",
        "transformer/*.safetensors*",
        "vae/config.json",
        "vae/*.safetensors",
        "text_encoder/config.json",
        "text_encoder/*.safetensors*",
        "tokenizer/tokenizer.json",
        "tokenizer/tokenizer_config.json",
        "scheduler/scheduler_config.json",
    )
    collected: List[Path] = []
    for entry in wanted:
        if "*" not in entry:
            candidate = model_dir / entry
            if candidate.exists():
                collected.append(candidate)
            continue
        collected += model_dir.glob(entry)
    collected.sort()
    return collected
def main():
    """Command-line entry point: write a manifest of a model directory's essential files.

    Each listed file becomes one line: ``MD5_HEX RELATIVE_PATH``, or just ``RELATIVE_PATH``
    with ``--no-checksums``. Output goes to ``--output`` or, by default, to
    ``config/manifests/MODEL-NAME.txt`` (lower-cased directory name) next to this package.

    Returns:
        Exit status: 0 on success; 1 if the model directory is missing, none of the
        essential files were found, or any checksum could not be computed. Files that
        fail are reported and omitted from the manifest.
    """
    parser = argparse.ArgumentParser(description="Generate manifest file for model weights")
    parser.add_argument("model_dir", type=str, help="Path to model directory")
    parser.add_argument("--output", "-o", type=str, default=None,
                        help="Output manifest file path (default: auto-detect to config/manifests/)")
    parser.add_argument("--no-checksums", action="store_true",
                        help="Only list files without computing checksums")
    parser.add_argument("--verbose", "-v", action="store_true",
                        help="Print progress")
    args = parser.parse_args()

    model_dir = Path(args.model_dir)
    if not model_dir.exists():
        print(f"Error: Model directory not found: {model_dir}")
        return 1

    # Determine output path
    if args.output:
        output_file = Path(args.output)
    else:
        # Auto-detect: save to config/manifests/{model-name}.txt
        model_name = model_dir.name.lower()  # e.g., "Z-Image-Turbo" -> "z-image-turbo"
        config_dir = Path(__file__).parent.parent / "config" / "manifests"
        config_dir.mkdir(parents=True, exist_ok=True)
        output_file = config_dir / f"{model_name}.txt"

    # Get essential files
    files = get_essential_files(model_dir)
    if not files:
        print(f"Warning: No essential files found in {model_dir}")
        return 1
    print(f"Found {len(files)} essential files")

    failed = []
    with open(output_file, "w", encoding="utf-8") as f:
        f.write("# Z-Image Model Manifest\n")
        # Describe the actual line format (the two branches used to write the same empty header).
        if args.no_checksums:
            f.write("# Format: RELATIVE_PATH\n")
        else:
            f.write("# Format: MD5_HEX RELATIVE_PATH\n")
        f.write("# Generated automatically - DO NOT edit manually\n\n")
        for file_path in files:
            # Always use "/" so a manifest generated on Windows matches one from POSIX.
            rel_path = file_path.relative_to(model_dir).as_posix()
            if args.no_checksums:
                f.write(f"{rel_path}\n")
                if args.verbose:
                    print(f" {rel_path}")
                continue
            if args.verbose:
                print(f"Computing MD5 for {rel_path}...", end=" ", flush=True)
            try:
                md5_hash = compute_md5(file_path)
            except Exception as e:  # best-effort: keep going, but remember the failure
                failed.append(rel_path)
                print(f"✗ Error ({rel_path}): {e}")
                continue
            f.write(f"{md5_hash} {rel_path}\n")
            if args.verbose:
                print(f"✓ {md5_hash}")

    print(f"\n✓ Manifest saved to: {output_file}")
    print(f" Total files: {len(files) - len(failed)}")
    if not args.no_checksums:
        print(" With MD5 checksums for integrity verification")
    if failed:
        # An incomplete manifest must not be reported as success.
        print(f"✗ {len(failed)} file(s) could not be checksummed and were omitted: {', '.join(failed)}")
        return 1
    return 0


if __name__ == "__main__":
    # SystemExit is always available; the `exit()` helper depends on the `site` module.
    raise SystemExit(main())
================================================
FILE: src/utils/__init__.py
================================================
"""Utilities for Z-Image."""
from .attention import AttentionBackend, dispatch_attention, set_attention_backend
from .helpers import format_bytes, print_memory_stats, ensure_model_weights
from .loader import load_from_local_dir
# Public utilities re-exported from .loader, .helpers and .attention.
__all__ = [
    "load_from_local_dir",
    "format_bytes",
    "print_memory_stats",
    "ensure_model_weights",
    "AttentionBackend",
    "set_attention_backend",
    "dispatch_attention",
]
================================================
FILE: src/utils/attention.py
================================================
"""Attention backend utilities for Z-Image."""
# Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_dispatch.py
from enum import Enum
import functools
import inspect
from typing import Callable, Dict, List, Optional, Union
import torch
import torch.nn.functional as F
from .import_utils import is_flash_attn_3_available, is_flash_attn_available, is_torch_version
# Import-time feature detection for the optional attention kernels.
_CAN_USE_FLASH_ATTN_2 = is_flash_attn_available()
_CAN_USE_FLASH_ATTN_3 = is_flash_attn_3_available()
# MPS Flash Attention (Apple Silicon)
try:
    import mps_flash_attn
    _CAN_USE_MPS_FLASH = mps_flash_attn.is_available()
except ImportError:
    # Package not installed: keep the module-level name bound; the MPS backend
    # checks _CAN_USE_MPS_FLASH before touching it.
    _CAN_USE_MPS_FLASH = False
    mps_flash_attn = None
_TORCH_VERSION_CHECK = is_torch_version(">=", "2.5.0") # have enable_gqa func call in SDPA
if not _TORCH_VERSION_CHECK:
    # Hard requirement: importing this module fails on older PyTorch.
    raise RuntimeError("PyTorch version must be >= 2.5.0 to use this backend.")
else:
    print("PyTorch version is >= 2.5.0, check pass.")
if _CAN_USE_FLASH_ATTN_2:
    from flash_attn import flash_attn_func, flash_attn_varlen_func
else:
    # Placeholders; the FLASH / FLASH_VARLEN backends raise before calling these.
    flash_attn_func = None
    flash_attn_varlen_func = None
if _CAN_USE_FLASH_ATTN_3:
    from flash_attn_interface import (
        flash_attn_func as flash_attn_3_func,
        flash_attn_varlen_func as flash_attn_3_varlen_func,
    )
    # Detect once, at import, whether the dense FA3 kernel accepts `return_attn_probs`.
    _flash_attn_3_sig = inspect.signature(flash_attn_3_func)
    _FLASH_ATTN_3_SUPPORTS_RETURN_PROBS = "return_attn_probs" in _flash_attn_3_sig.parameters
else:
    # Placeholders; the FA3 backends raise before calling these.
    flash_attn_3_func = None
    flash_attn_3_varlen_func = None
    _FLASH_ATTN_3_SUPPORTS_RETURN_PROBS = False
class AttentionBackend(str, Enum):
    """Names of the attention implementations that ``dispatch_attention`` understands.

    Members subclass ``str``, so each compares equal to its plain string value,
    e.g. ``AttentionBackend.NATIVE == "native"``.
    """

    # Flash Attention (flash-attn 2 and the Flash Attention 3 beta)
    FLASH = "flash"
    FLASH_VARLEN = "flash_varlen"
    FLASH_3 = "_flash_3"
    FLASH_VARLEN_3 = "_flash_varlen_3"
    # MPS Flash Attention (Apple Silicon)
    MPS_FLASH = "mps_flash"
    # PyTorch native scaled_dot_product_attention variants
    NATIVE = "native"
    NATIVE_FLASH = "_native_flash"
    NATIVE_MATH = "_native_math"

    @classmethod
    def print_available_backends(cls):
        """Print the string value of every backend, in declaration order."""
        registry = cls.__members__
        values = [registry[member_name].value for member_name in registry]
        print(f"Available attention backends list: {values}")
# Registry for attention implementations: backend name -> implementation.
_ATTENTION_BACKENDS: Dict[str, Callable] = {}
# Backend name -> constraint checks registered alongside the implementation.
_ATTENTION_CONSTRAINTS: Dict[str, List[Callable]] = {}


def register_backend(name: str, constraints: Optional[List[Callable]] = None):
    """Build a decorator that records a function under ``name`` in both registries.

    The decorated function is returned unchanged. A missing or empty ``constraints``
    argument is stored as a fresh empty list.
    """

    def _record(impl):
        _ATTENTION_BACKENDS[name] = impl
        _ATTENTION_CONSTRAINTS[name] = constraints if constraints else []
        return impl

    return _record
# --- Checks ---
def _check_device_cuda(query: torch.Tensor, **kwargs) -> None:
    """Raise ``ValueError`` unless ``query`` lives on a CUDA device; extra kwargs are ignored."""
    is_cuda = query.device.type == "cuda"
    if is_cuda:
        return
    raise ValueError("Query must be on a CUDA device.")
def _check_qkv_dtype_bf16_or_fp16(query: torch.Tensor, **kwargs) -> None:
    """Raise ``ValueError`` unless ``query`` is bfloat16 or float16; extra kwargs are ignored."""
    half_precision = (torch.bfloat16, torch.float16)
    if query.dtype in half_precision:
        return
    raise ValueError("Query must be either bfloat16 or float16.")
def _check_device_mps(query: torch.Tensor, **kwargs) -> None:
    """Raise ``ValueError`` unless ``query`` lives on an Apple MPS device; extra kwargs are ignored."""
    is_mps = query.device.type == "mps"
    if is_mps:
        return
    raise ValueError("Query must be on MPS device.")
def _process_mask(attn_mask: Optional[torch.Tensor], dtype: torch.dtype):
    """Prepare an attention mask for ``F.scaled_dot_product_attention``.

    ``None`` passes through. A 2-D mask gets singleton axes inserted at positions 1 and 2,
    so ``[B, K]`` becomes ``[B, 1, 1, K]``. A boolean mask is converted into an additive
    mask of ``dtype``: 0 where True (attend) and ``-inf`` where False. A non-boolean
    mask is returned as-is apart from that reshape.
    """
    if attn_mask is None:
        return None
    mask = attn_mask.unsqueeze(1).unsqueeze(2) if attn_mask.ndim == 2 else attn_mask
    if mask.dtype != torch.bool:
        return mask
    # NOTE: no torch.all() fast path for all-True masks, to avoid torch.compile graph breaks.
    additive = torch.zeros_like(mask, dtype=dtype)
    return additive.masked_fill_(~mask, float("-inf"))
def _normalize_attn_mask(attn_mask: torch.Tensor, batch_size: int, seq_len_k: int) -> torch.Tensor:
"""Normalize an attention mask to shape [batch_size, seq_len_k] (bool)."""
if attn_mask.dtype != torch.bool:
# Try to convert float mask back to bool if possible, or assume it's float mask
# For varlen flash attn, we strictly need bool mask indicating valid tokens
if torch.is_floating_point(attn_mask):
return attn_mask > -1 # Assuming -inf is masked
# raise ValueError(f"Attention mask must be of type bool, got {attn_mask.dtype}.")
if attn_mask.ndim == 1:
attn_mask = attn_mask.unsqueeze(0).expand(batch_size, seq_len_k)
elif attn_mask.ndim == 2:
if attn_mask.size(0) not in [1, batch_size]:
attn_mask = attn_mask.expand(batch_size, seq_len_k)
elif attn_mask.ndim == 3:
attn_mask = attn_mask.any(dim=1)
attn_mask = attn_mask.expand(batch_size, seq_len_k)
elif attn_mask.ndim == 4:
attn_mask = attn_mask.expand(batch_size, -1, -1, seq_len_k)
attn_mask = attn_mask.any(dim=(1, 2))
if attn_mask.shape != (batch_size, seq_len_k):
# Fallback reshape
return attn_mask.view(batch_size, seq_len_k)
return attn_mask
@functools.lru_cache(maxsize=128)
def _prepare_for_flash_attn_varlen_without_mask(
    batch_size: int,
    seq_len_q: int,
    seq_len_kv: int,
    device: Optional[torch.device] = None,
):
    """Build varlen metadata for an unmasked batch in which every sequence has full length.

    Returns ``((seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_q, max_k))``.
    Per-sequence lengths are constant and the cumulative offsets are ``0, L, 2L, ...``.
    Results are memoized per argument tuple, so the returned tensors are shared between
    calls and must not be modified in place.
    """
    # arange(...) * L rather than cumsum: sidesteps the Inductor
    # "pointless_cumsum_replacement" crash and avoids graph breaks.
    batch_offsets = torch.arange(batch_size + 1, dtype=torch.int32, device=device)
    per_seq_lens = (
        torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device),
        torch.full((batch_size,), seq_len_kv, dtype=torch.int32, device=device),
    )
    cumulative = (batch_offsets * seq_len_q, batch_offsets * seq_len_kv)
    return per_seq_lens, cumulative, (seq_len_q, seq_len_kv)
def _prepare_for_flash_attn_varlen_with_mask(
    batch_size: int,
    seq_len_q: int,
    attn_mask: torch.Tensor,
    device: Optional[torch.device] = None,
):
    """Build varlen metadata when keys are masked by a ``[batch, seq_len_k]`` mask.

    Queries keep their full length. Each batch element's key length is the number of
    set entries in its mask row. ``max_seqlen_k`` is the padded mask width rather than
    the true maximum, which keeps shapes static (no ``.item()`` sync or graph break).
    """
    key_lens = attn_mask.sum(dim=1, dtype=torch.int32)
    query_lens = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device)
    # Uniform query lengths, so arange-based offsets work (and avoid an Inductor crash).
    query_offsets = seq_len_q * torch.arange(batch_size + 1, dtype=torch.int32, device=device)
    key_offsets = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
    key_offsets[1:] = key_lens.cumsum(dim=0)
    return (query_lens, key_lens), (query_offsets, key_offsets), (seq_len_q, attn_mask.shape[1])
def _prepare_for_flash_attn_varlen(
    batch_size: int,
    seq_len_q: int,
    seq_len_kv: int,
    attn_mask: Optional[torch.Tensor] = None,
    device: Optional[torch.device] = None,
) -> tuple[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor, torch.Tensor], tuple[int, int]]:
    """Build the sequence-length metadata needed by flash-attention varlen kernels.

    Args:
        batch_size: Number of sequences.
        seq_len_q: Padded query length (all queries are used).
        seq_len_kv: Padded key/value length. Only used when ``attn_mask`` is None.
        attn_mask: Optional ``[batch, seq_len_kv]`` key-validity mask.
        device: Device for the returned tensors.

    Returns:
        ``((seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k))``.
        The return annotation was previously ``None`` although this tuple is always returned.
    """
    if attn_mask is None:
        # Cached: uniform lengths depend only on the scalar arguments.
        return _prepare_for_flash_attn_varlen_without_mask(batch_size, seq_len_q, seq_len_kv, device)
    return _prepare_for_flash_attn_varlen_with_mask(batch_size, seq_len_q, attn_mask, device)
@register_backend(AttentionBackend.FLASH, constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16])
def _flash_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
) -> torch.Tensor:
    """Dense attention through flash-attn 2's ``flash_attn_func``.

    Tensors are forwarded unchanged, so they must already be in the layout that
    ``flash_attn_func`` takes. NOTE: ``attn_mask`` is accepted for interface
    compatibility but is NOT applied on this path.
    """
    if _CAN_USE_FLASH_ATTN_2:
        return flash_attn_func(
            q=query,
            k=key,
            v=value,
            dropout_p=dropout_p,
            softmax_scale=scale,
            causal=is_causal,
        )
    raise RuntimeError(
        f"Flash Attention backend '{AttentionBackend.FLASH}' is not usable because of missing package."
    )
@register_backend(AttentionBackend.FLASH_VARLEN, constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16])
def _flash_varlen_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
) -> torch.Tensor:
    """Masked attention through flash-attn 2's variable-length kernel.

    Inputs are unpacked as ``(batch, seq, heads, head_dim)``. Queries are always packed
    at full length. When ``attn_mask`` is given, it is normalized to a per-key validity
    mask ``[batch, seq_len_kv]`` and only the valid key/value tokens are packed, so
    padded keys take no part in attention. The output is unflattened back to
    ``(batch, seq_len_q, ...)``.
    """
    if not _CAN_USE_FLASH_ATTN_2:
        raise RuntimeError(f"Backend '{AttentionBackend.FLASH_VARLEN}' requires flash-attn.")

    batch_size, seq_len_q, _, _ = query.shape
    _, seq_len_kv, _, _ = key.shape
    if attn_mask is not None:
        attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
    _, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = _prepare_for_flash_attn_varlen(
        batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
    )

    query_packed = query.flatten(0, 1)
    if attn_mask is not None:
        # Boolean indexing gathers exactly the valid tokens in batch-major order, which
        # matches the offsets in cu_seqlens_k. Unlike slicing key[b, :count], it stays
        # correct when the valid tokens are not a contiguous prefix, and it needs no
        # per-batch host synchronisation.
        valid = attn_mask.to(torch.bool)
        key_packed = key[valid]
        value_packed = value[valid]
    else:
        key_packed = key.flatten(0, 1)
        value_packed = value.flatten(0, 1)

    out = flash_attn_varlen_func(
        q=query_packed,
        k=key_packed,
        v=value_packed,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        dropout_p=dropout_p,
        softmax_scale=scale,
        causal=is_causal,
    )
    return out.unflatten(0, (batch_size, -1))
@register_backend(AttentionBackend.FLASH_3, constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16])
def _flash_attention_3(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,  # Unused in simple FA3 func
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
) -> torch.Tensor:
    """Dense attention through the Flash Attention 3 beta ``flash_attn_func``.

    NOTE: ``attn_mask`` and ``dropout_p`` are accepted for interface compatibility
    but are not passed to the kernel.
    """
    if not _CAN_USE_FLASH_ATTN_3:
        raise RuntimeError(f"Backend '{AttentionBackend.FLASH_3}' requires Flash Attention 3 beta.")
    call_kwargs = dict(q=query, k=key, v=value, softmax_scale=scale, causal=is_causal)
    if _FLASH_ATTN_3_SUPPORTS_RETURN_PROBS:
        call_kwargs["return_attn_probs"] = False
    result = flash_attn_3_func(**call_kwargs)
    # The kernel may return a tuple; only its first element (the output) is used.
    return result[0] if isinstance(result, tuple) else result
# Detect once, at import time, whether the FA3 varlen kernel accepts `return_attn_probs`,
# so the per-call path does not run `inspect.signature` on every attention call.
# Signature introspection can fail for some compiled callables; treat that as "not supported".
try:
    _FLASH_ATTN_3_VARLEN_SUPPORTS_RETURN_PROBS = (
        flash_attn_3_varlen_func is not None
        and "return_attn_probs" in inspect.signature(flash_attn_3_varlen_func).parameters
    )
except (TypeError, ValueError):
    _FLASH_ATTN_3_VARLEN_SUPPORTS_RETURN_PROBS = False


@register_backend(AttentionBackend.FLASH_VARLEN_3, constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16])
def _flash_varlen_attention_3(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
) -> torch.Tensor:
    """Masked attention through the Flash Attention 3 beta variable-length kernel.

    Inputs are unpacked as ``(batch, seq, heads, head_dim)``. Queries are packed at full
    length. With ``attn_mask``, only the valid key/value tokens (per the normalized
    ``[batch, seq_len_kv]`` mask) are packed. ``dropout_p`` is accepted for interface
    compatibility but is not passed to the kernel. The output is unflattened back to
    ``(batch, seq_len_q, ...)``.
    """
    if not _CAN_USE_FLASH_ATTN_3:
        raise RuntimeError(f"Backend '{AttentionBackend.FLASH_VARLEN_3}' requires Flash Attention 3 beta.")

    batch_size, seq_len_q, _, _ = query.shape
    _, seq_len_kv, _, _ = key.shape
    if attn_mask is not None:
        attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
    _, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = _prepare_for_flash_attn_varlen(
        batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
    )

    query_packed = query.flatten(0, 1)
    if attn_mask is not None:
        # Boolean indexing gathers the valid tokens in batch-major order, matching
        # cu_seqlens_k. It is correct even for non-prefix masks and avoids host syncs.
        valid = attn_mask.to(torch.bool)
        key_packed = key[valid]
        value_packed = value[valid]
    else:
        key_packed = key.flatten(0, 1)
        value_packed = value.flatten(0, 1)

    kwargs = {
        "q": query_packed,
        "k": key_packed,
        "v": value_packed,
        "cu_seqlens_q": cu_seqlens_q,
        "cu_seqlens_k": cu_seqlens_k,
        "max_seqlen_q": max_seqlen_q,
        "max_seqlen_k": max_seqlen_k,
        "softmax_scale": scale,
        "causal": is_causal,
    }
    if _FLASH_ATTN_3_VARLEN_SUPPORTS_RETURN_PROBS:
        kwargs["return_attn_probs"] = False
    out = flash_attn_3_varlen_func(**kwargs)
    if isinstance(out, tuple):
        out = out[0]
    return out.unflatten(0, (batch_size, -1))
@register_backend(AttentionBackend.MPS_FLASH, constraints=[_check_device_mps, _check_qkv_dtype_bf16_or_fp16])
def _mps_flash_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
) -> torch.Tensor:
    """MPS Flash Attention for Apple Silicon (M1/M2/M3/M4), via the ``mps_flash_attn`` package.

    Takes and returns ``(B, S, H, D)`` tensors and transposes to ``(B, H, S, D)`` for the
    kernel. ``dropout_p`` is accepted for interface compatibility but not used.
    """
    if not _CAN_USE_MPS_FLASH:
        raise RuntimeError(
            f"MPS Flash Attention backend '{AttentionBackend.MPS_FLASH}' requires mps-flash-attn package. "
            "Install with: pip install mps-flash-attn"
        )
    # The mask is first turned into an SDPA-style mask by _process_mask, then handed to the
    # package's own converter (per the original note, MFA wants bool with True = masked).
    mfa_mask = None if attn_mask is None else mps_flash_attn.convert_mask(_process_mask(attn_mask, query.dtype))
    q_bhsd = query.transpose(1, 2)
    k_bhsd = key.transpose(1, 2)
    v_bhsd = value.transpose(1, 2)
    out_bhsd = mps_flash_attn.flash_attention(
        q_bhsd,
        k_bhsd,
        v_bhsd,
        is_causal=is_causal,
        scale=scale,
        attn_mask=mfa_mask,
    )
    # Back to (B, S, H, D)
    return out_bhsd.transpose(1, 2).contiguous()
def _native_attention_wrapper(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
    backend_kernel=None,
) -> torch.Tensor:
    """Run ``F.scaled_dot_product_attention`` on ``(B, S, H, D)`` inputs.

    SDPA works in ``(B, H, S, D)``, so q/k/v are transposed on the way in and the result
    is transposed back (and made contiguous) on the way out. The mask is passed through
    ``_process_mask`` first. When ``backend_kernel`` is given, SDPA is restricted to that
    kernel with ``torch.nn.attention.sdpa_kernel``; otherwise PyTorch chooses.
    """
    q = query.transpose(1, 2)
    k = key.transpose(1, 2)
    v = value.transpose(1, 2)
    mask = _process_mask(attn_mask, q.dtype)

    def _sdpa():
        return F.scaled_dot_product_attention(
            q, k, v, attn_mask=mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale
        )

    if backend_kernel is None:
        result = _sdpa()
    else:
        with torch.nn.attention.sdpa_kernel(backend_kernel):
            result = _sdpa()
    return result.transpose(1, 2).contiguous()
@register_backend(AttentionBackend.NATIVE_FLASH)
def _native_flash_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
) -> torch.Tensor:
    """PyTorch SDPA restricted to its FlashAttention kernel.

    NOTE(review): ``attn_mask`` is deliberately discarded (``None`` is forwarded), so padded
    keys are NOT masked on this path. Presumably this is because SDPA's flash kernel does
    not accept an explicit mask. Confirm this is acceptable for padded inputs.
    """
    flash_kernel = torch.nn.attention.SDPBackend.FLASH_ATTENTION
    return _native_attention_wrapper(
        query,
        key,
        value,
        attn_mask=None,
        dropout_p=dropout_p,
        is_causal=is_causal,
        scale=scale,
        backend_kernel=flash_kernel,
    )
@register_backend(AttentionBackend.NATIVE_MATH)
def _math_attention(*args, **kwargs):
    """PyTorch SDPA restricted to the reference ``MATH`` kernel.

    Accepts the arguments of ``_native_attention_wrapper`` except ``backend_kernel``.
    """
    math_kernel = torch.nn.attention.SDPBackend.MATH
    return _native_attention_wrapper(*args, backend_kernel=math_kernel, **kwargs)
@register_backend(AttentionBackend.NATIVE)
def _native_attention(*args, **kwargs):
    """PyTorch SDPA with no kernel restriction (PyTorch selects the kernel).

    Accepts the arguments of ``_native_attention_wrapper`` except ``backend_kernel``.
    """
    return _native_attention_wrapper(*args, backend_kernel=None, **kwargs)
def dispatch_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
    backend: Union[str, AttentionBackend, None] = None,
) -> torch.Tensor:
    """Route one attention call to the implementation selected by ``backend``.

    Args:
        query, key, value: Attention inputs, forwarded unchanged to the chosen
            backend. The native and MPS backends in this module swap dims 1 and 2
            on entry and exit (the MPS one documents (B, S, H, D)); TODO confirm
            the flash backends defined elsewhere use the same layout.
        attn_mask: Optional mask, forwarded as-is (some backends may ignore it,
            e.g. ``_native_flash_attention`` always passes ``None``).
        dropout_p, is_causal, scale: Forwarded positionally to the backend.
        backend: An ``AttentionBackend`` member, its string value, or ``None``.
            ``None`` and any value matching none of the branches below fall
            through to ``_native_attention``.

    Returns:
        Whatever the selected backend returns.
    """
    # Normalize to a plain string: enum members via ``.value``, other non-None
    # inputs via ``str()``. ``None`` becomes the NATIVE member, which reaches the
    # final ``else`` branch.
    if isinstance(backend, AttentionBackend):
        backend = backend.value
    elif backend is None:
        backend = AttentionBackend.NATIVE
    else:
        backend = str(backend)
    # Explicit dispatch to avoid dynamo guard issues on global dict
    # NOTE(review): these string-vs-member ``==`` checks only match if
    # AttentionBackend members compare equal to their values (str-mixin enum) —
    # confirm in the enum definition.
    if backend == AttentionBackend.FLASH:
        return _flash_attention(query, key, value, attn_mask, dropout_p, is_causal, scale)
    elif backend == AttentionBackend.FLASH_VARLEN:
        return _flash_varlen_attention(query, key, value, attn_mask, dropout_p, is_causal, scale)
    elif backend == AttentionBackend.FLASH_3:
        return _flash_attention_3(query, key, value, attn_mask, dropout_p, is_causal, scale)
    elif backend == AttentionBackend.FLASH_VARLEN_3:
        return _flash_varlen_attention_3(query, key, value, attn_mask, dropout_p, is_causal, scale)
    elif backend == AttentionBackend.MPS_FLASH:
        return _mps_flash_attention(query, key, value, attn_mask, dropout_p, is_causal, scale)
    elif backend == AttentionBackend.NATIVE_FLASH:
        return _native_flash_attention(query, key, value, attn_mask, dropout_p, is_causal, scale)
    elif backend == AttentionBackend.NATIVE_MATH:
        return _math_attention(query, key, value, attn_mask, dropout_p, is_causal, scale)
    else:
        # Default: unpinned PyTorch SDPA (also the silent fallback for unknown names).
        return _native_attention(query, key, value, attn_mask, dropout_p, is_causal, scale)
def set_attention_backend(backend: Union[str, AttentionBackend, None]):
    """Set the attention backend used by every ``ZImageAttention`` module.

    The value is stored on the ``ZImageAttention`` class attribute
    ``_attention_backend``.

    Args:
        backend: An ``AttentionBackend`` member, its string value, or ``None``.
            Enum members are stored as ``member.value`` — the same normalization
            ``dispatch_attention`` applies — rather than ``str(member)``, whose
            output for an Enum depends on how the enum class defines ``__str__``.
            Other non-None values are stored as ``str(backend)``.

    This is best-effort: if ``zimage.transformer`` cannot be imported, a warning
    is emitted and nothing is changed.
    """
    if isinstance(backend, AttentionBackend):
        backend = backend.value
    elif backend is not None:
        backend = str(backend)
    try:
        from zimage.transformer import ZImageAttention
    except ImportError as e:
        import warnings

        warnings.warn(
            f"set_attention_backend: could not import zimage.transformer ({e}); backend not applied.",
            RuntimeWarning,
            stacklevel=2,
        )
        return
    ZImageAttention._attention_backend = backend
================================================
FILE: src/utils/helpers.py
================================================
"""Helper utilities for Z-Image."""
import hashlib
import json
from pathlib import Path
from typing import Optional, List, Tuple, Dict
from loguru import logger
import torch
from config import BYTES_PER_GB
def format_bytes(size: float) -> str:
    """Render a byte count in gigabytes.

    Args:
        size: Number of bytes.

    Returns:
        ``size / BYTES_PER_GB`` with two decimals, followed by ``" GB"``.
    """
    gigabytes = size / BYTES_PER_GB
    return "{:.2f} GB".format(gigabytes)
def print_memory_stats(stage: str) -> None:
    """Log current and peak CUDA memory counters, tagged with ``stage``.

    Synchronizes the CUDA device before reading the counters. If CUDA is not
    available, logs a warning and returns without reporting anything.

    Args:
        stage: Label included in the log header line.
    """
    if not torch.cuda.is_available():
        logger.warning("CUDA not available, skipping memory stats")
        return
    torch.cuda.synchronize()
    peak_allocated = torch.cuda.max_memory_allocated()
    peak_reserved = torch.cuda.max_memory_reserved()
    now_allocated = torch.cuda.memory_allocated()
    now_reserved = torch.cuda.memory_reserved()
    rows = (
        ("Current Allocated", now_allocated),
        ("Current Reserved", now_reserved),
        ("Peak Allocated", peak_allocated),
        ("Peak Reserved", peak_reserved),
    )
    logger.info(f"[{stage}] Memory Stats:")
    for label, num_bytes in rows:
        logger.info(f" {label}: {format_bytes(num_bytes)}")
def compute_file_md5(file_path: Path, chunk_size: int = 8192) -> str:
    """Return the lower-case hex MD5 digest of a file.

    The file is streamed in ``chunk_size``-byte reads, so large weight files are
    never held in memory at once.

    Args:
        file_path: File to hash.
        chunk_size: Bytes per read.
    """
    digest = hashlib.md5()
    with open(file_path, "rb") as stream:
        for block in iter(lambda: stream.read(chunk_size), b""):
            digest.update(block)
    return digest.hexdigest()
def load_manifest(manifest_file: Path) -> Dict[str, Optional[str]]:
    """Load a file manifest mapping relative paths to optional MD5 hashes.

    Blank lines and lines starting with ``#`` are skipped. Each remaining line
    is one of:

    * ``<path>`` — the file must exist; no checksum is recorded.
    * ``<md5> <path>`` — ``md5sum`` order. A leading ``*`` on the path (the
      binary-mode marker written by ``md5sum -b``) is stripped.
    * ``<path> <md5>``.

    A first field of exactly 32 hex characters is treated as the hash.
    Hashes are lower-cased so they compare equal to ``hexdigest()`` output.
    Lines with more than two whitespace-separated fields are logged and skipped.

    Args:
        manifest_file: Path of the manifest text file.

    Returns:
        Dict of relative path -> lower-case MD5 hex string, or ``None`` when no
        hash was given. Empty if ``manifest_file`` does not exist.
    """
    manifest: Dict[str, Optional[str]] = {}
    if not manifest_file.exists():
        return manifest
    with open(manifest_file, "r", encoding="utf-8") as f:
        for line_num, raw_line in enumerate(f, 1):
            line = raw_line.strip()
            # Skip empty lines and comments
            if not line or line.startswith("#"):
                continue
            parts = line.split()
            if len(parts) == 1:
                # Only file path, no checksum
                manifest[parts[0]] = None
            elif len(parts) == 2:
                first = parts[0]
                if len(first) == 32 and all(c in "0123456789abcdef" for c in first.lower()):
                    # md5sum layout: "<hash> <path>" or "<hash> *<path>" (binary mode)
                    md5_hash, file_path = parts
                    if file_path.startswith("*"):
                        file_path = file_path[1:]
                else:
                    file_path, md5_hash = parts
                # hexdigest() is lower-case; normalize so upper-case manifests still match.
                manifest[file_path] = md5_hash.lower()
            else:
                logger.warning(f"Invalid manifest format at line {line_num}: {line}")
    return manifest
def verify_file_integrity(
    base_dir: Path,
    manifest: Dict[str, Optional[str]],
    verify_checksums: bool = True
) -> Tuple[bool, List[str], List[str]]:
    """
    Verify file integrity using a manifest.

    Args:
        base_dir: Base directory for relative file paths
        manifest: Dictionary of relative paths to MD5 hashes (None if no hash provided)
        verify_checksums: If True, verify MD5 checksums when available; if False, only check existence

    Returns:
        Tuple of (all_valid: bool, missing_files: List[str], corrupted_files: List[str]).
        A file whose checksum cannot be computed is reported as corrupted.
    """
    missing: List[str] = []
    corrupted: List[str] = []
    for rel_path, expected_md5 in manifest.items():
        file_path = base_dir / rel_path
        if not file_path.exists():
            missing.append(rel_path)
            continue
        # Only verify checksum if requested AND hash is available
        if not verify_checksums or expected_md5 is None:
            continue
        try:
            actual_md5 = compute_file_md5(file_path)
        except Exception as e:
            logger.error(f"Failed to compute checksum for {rel_path}: {e}")
            corrupted.append(rel_path)
            continue
        # hexdigest() is lower-case; compare case-insensitively so manifests
        # that were written with upper-case hashes still validate.
        if actual_md5 != expected_md5.lower():
            corrupted.append(rel_path)
            logger.debug(f"Checksum mismatch for {rel_path}: expected {expected_md5}, got {actual_md5}")
    all_valid = not missing and not corrupted
    return all_valid, missing, corrupted
def _resolve_manifest_path(target_dir: Path, manifest_name: Optional[str]) -> Path:
    """Pick the manifest file used to validate ``target_dir``.

    Lookup order: an explicit ``manifest_name`` under ``src/config/manifests/``;
    else ``<lower-cased target_dir name>.txt`` in that directory if it exists;
    else ``target_dir / "manifest.txt"``.
    """
    manifests_dir = Path(__file__).parent.parent / "config" / "manifests"
    if manifest_name:
        # Explicitly specified manifest from config/manifests/
        return manifests_dir / manifest_name
    # Auto-detect, e.g. "Z-Image-Turbo" -> "z-image-turbo.txt"
    config_manifest = manifests_dir / f"{target_dir.name.lower()}.txt"
    if config_manifest.exists():
        return config_manifest
    # Fallback
    return target_dir / "manifest.txt"


def _log_file_problems(missing_files: List[str], corrupted_files: List[str], limit: int = 10) -> None:
    """Log up to ``limit`` missing and corrupted paths, summarizing the remainder."""
    if missing_files:
        logger.warning(f"Missing {len(missing_files)} file(s):")
        for f in missing_files[:limit]:
            logger.warning(f" - {f}")
        if len(missing_files) > limit:
            logger.warning(f" ... and {len(missing_files) - limit} more")
    if corrupted_files:
        logger.error(f"Corrupted {len(corrupted_files)} file(s) (checksum mismatch):")
        for f in corrupted_files[:limit]:
            logger.error(f" - {f}")
        if len(corrupted_files) > limit:
            logger.error(f" ... and {len(corrupted_files) - limit} more")


def ensure_model_weights(
    model_path: str,
    repo_id: str = "Tongyi-MAI/Z-Image-Turbo",
    verify: bool = False,
    manifest_name: Optional[str] = None
) -> Path:
    """
    Ensure model weights exist and optionally verify integrity.

    If the manifest is missing or empty, an existing ``model_path`` directory is
    accepted as-is. Otherwise every manifest entry must exist (and match its MD5
    when ``verify`` is set). On failure the whole repo is fetched with
    ``huggingface_hub.snapshot_download`` into ``model_path`` and re-checked.

    Args:
        model_path: Path to model directory
        repo_id: HuggingFace repo ID for download
        verify: If True, verify MD5 checksums; if False, only check existence
        manifest_name: Manifest file name in src/config/manifests/ (auto-detect if None)

    Returns:
        Path to validated model directory

    Raises:
        RuntimeError: The download failed (original exception chained as the cause).
        FileNotFoundError: Files are still missing/corrupted after downloading.
    """
    from huggingface_hub import snapshot_download

    target_dir = Path(model_path)
    manifest_path = _resolve_manifest_path(target_dir, manifest_name)
    manifest = load_manifest(manifest_path)

    if not manifest:
        # load_manifest returns {} both for a missing file and for one without entries.
        if manifest_path.exists():
            logger.warning(f"Manifest file is empty: {manifest_path}")
        else:
            logger.warning(f"Manifest file not found: {manifest_path}")
        logger.warning("Skipping file verification (assuming model exists)")
        if target_dir.exists():
            logger.info(f"✓ Model directory exists: {target_dir}")
            return target_dir
        logger.warning(f"Model directory not found: {target_dir}")
        missing_files = ["entire model directory"]
        corrupted_files = []
    else:
        # Count files with checksums
        files_with_checksums = sum(1 for v in manifest.values() if v is not None)
        if verify and files_with_checksums == 0:
            logger.info("Verify requested but no checksums in manifest, only checking existence")
        elif verify:
            logger.info(f"Verifying {files_with_checksums} file(s) with MD5 checksums...")
        all_valid, missing_files, corrupted_files = verify_file_integrity(
            target_dir, manifest, verify_checksums=verify
        )
        if all_valid:
            if verify and files_with_checksums > 0:
                logger.success(f"✓ All files verified with MD5 checksums in {target_dir}")
            else:
                logger.info(f"✓ All {len(manifest)} required files exist in {target_dir}")
            return target_dir

    _log_file_problems(missing_files, corrupted_files)

    # Download model weights
    logger.info(f"\nAttempting to download from {repo_id}...")
    try:
        target_dir.mkdir(parents=True, exist_ok=True)
        # NOTE(review): newer huggingface_hub releases deprecate `local_dir_use_symlinks`
        # and `resume_download`; confirm against the pinned version before removing.
        snapshot_download(
            repo_id=repo_id,
            local_dir=str(target_dir),
            local_dir_use_symlinks=False,
            resume_download=True,
        )
    except Exception as e:
        logger.error(f"✗ Download failed: {e}")
        logger.info(
            f"\nIf you are offline, please manually download from:\n"
            f" https://huggingface.co/{repo_id}\n"
            f"and place in: {target_dir.absolute()}"
        )
        raise RuntimeError(f"Failed to download model weights: {e}") from e
    logger.success("✓ Download completed")

    # Verify after download
    if manifest:
        all_valid, missing_after, corrupted_after = verify_file_integrity(
            target_dir, manifest, verify_checksums=verify
        )
        if not all_valid:
            error_msg = []
            if missing_after:
                error_msg.append(f"Still missing {len(missing_after)} file(s)")
            if corrupted_after:
                error_msg.append(f"Still corrupted {len(corrupted_after)} file(s)")
            raise FileNotFoundError(
                f"After download: {', '.join(error_msg)}\n"
                f"Please verify the download or manually place files in:\n"
                f" {target_dir.absolute()}"
            )
    logger.success("✓ All model weights validated successfully")
    return target_dir
================================================
FILE: src/utils/import_utils.py
================================================
import importlib.util
import torch
def is_flash_attn_available():
    """Return True if the ``flash_attn`` package is findable (it is not imported)."""
    spec = importlib.util.find_spec("flash_attn")
    return spec is not None
def is_flash_attn_3_available():
    """Return True if ``flash_attn_interface`` (FlashAttention-3) is findable; nothing is imported."""
    spec = importlib.util.find_spec("flash_attn_interface")
    return spec is not None
def is_torch_version(operator: str, version: str):
    """Compare the installed torch version against ``version``.

    Args:
        operator: One of ``">"``, ``">="``, ``"=="``, ``"!="``, ``"<="``, ``"<"``.
            Any other string returns ``False`` (kept for backward compatibility).
        version: A PEP 440 version string, e.g. ``"2.5.0"``.

    Returns:
        ``torch.__version__ <operator> version`` using ``packaging`` version
        ordering. ``"=="`` is a strict comparison, so a local build tag such as
        ``+cu121`` makes it unequal to the plain release.
    """
    import operator as _operator

    from packaging import version as pversion

    torch_version = pversion.parse(torch.__version__)
    target_version = pversion.parse(version)
    comparisons = {
        ">": _operator.gt,
        ">=": _operator.ge,
        "==": _operator.eq,
        "!=": _operator.ne,
        "<=": _operator.le,
        "<": _operator.lt,
    }
    compare = comparisons.get(operator)
    if compare is None:
        return False
    return compare(torch_version, target_version)
================================================
FILE: src/utils/loader.py
================================================
"""Model loading utilities for Z-Image components."""
import json
import os
from pathlib import Path
import sys
from typing import Optional, Union
from loguru import logger
from safetensors.torch import load_file
import torch
from transformers import AutoModel, AutoTokenizer
from config import (
DEFAULT_SCHEDULER_NUM_TRAIN_TIMESTEPS,
DEFAULT_SCHEDULER_SHIFT,
DEFAULT_SCHEDULER_USE_DYNAMIC_SHIFTING,
DEFAULT_TRANSFORMER_CAP_FEAT_DIM,
DEFAULT_TRANSFORMER_DIM,
DEFAULT_TRANSFORMER_F_PATCH_SIZE,
DEFAULT_TRANSFORMER_IN_CHANNELS,
DEFAULT_TRANSFORMER_N_HEADS,
DEFAULT_TRANSFORMER_N_KV_HEADS,
DEFAULT_TRANSFORMER_N_LAYERS,
DEFAULT_TRANSFORMER_N_REFINER_LAYERS,
DEFAULT_TRANSFORMER_NORM_EPS,
DEFAULT_TRANSFORMER_PATCH_SIZE,
DEFAULT_TRANSFORMER_QK_NORM,
DEFAULT_TRANSFORMER_T_SCALE,
DEFAULT_VAE_IN_CHANNELS,
DEFAULT_VAE_LATENT_CHANNELS,
DEFAULT_VAE_NORM_NUM_GROUPS,
DEFAULT_VAE_OUT_CHANNELS,
DEFAULT_VAE_SCALING_FACTOR,
ROPE_AXES_DIMS,
ROPE_AXES_LENS,
ROPE_THETA,
)
from zimage.autoencoder import AutoencoderKL as LocalAutoencoderKL
from zimage.scheduler import FlowMatchEulerDiscreteScheduler
# Module-level flag hard-coded to False; nothing in this module reads it.
# NOTE(review): possibly imported elsewhere — confirm usage or remove.
DIFFUSERS_AVAILABLE = False
def load_config(config_path: Union[str, Path]) -> dict:
    """Read and parse a JSON config file.

    The file is decoded as UTF-8 explicitly, which is the JSON standard encoding,
    rather than the platform's locale default (e.g. cp1252 on Windows).

    Args:
        config_path: Path to the JSON file (``str`` or ``Path``).

    Returns:
        The parsed JSON object.
    """
    with open(config_path, "r", encoding="utf-8") as f:
        return json.load(f)
def load_sharded_safetensors(weight_dir: Path, device: str = "cuda", dtype: Optional[torch.dtype] = None) -> dict:
    """Load a (possibly sharded) safetensors checkpoint from a directory.

    If a ``*.safetensors.index.json`` exists, every shard named in its
    ``weight_map`` is loaded and merged. Otherwise a single ``*.safetensors``
    file is loaded. Candidates are sorted, so the choice is deterministic; a
    warning is logged if several unindexed files exist, since only the first
    is loaded.

    Args:
        weight_dir: Directory containing the checkpoint.
        device: Device passed to ``safetensors.torch.load_file``.
        dtype: If given, floating-point tensors are cast to it. Integer and bool
            tensors are left untouched; casting them to a float dtype would
            corrupt their values.

    Returns:
        Mapping of parameter name to tensor.

    Raises:
        FileNotFoundError: No index and no ``*.safetensors`` file in ``weight_dir``.
    """
    weight_dir = Path(weight_dir)
    index_files = sorted(weight_dir.glob("*.safetensors.index.json"))
    state_dict = {}
    if index_files:
        # Load sharded weights
        with open(index_files[0], "r", encoding="utf-8") as f:
            index = json.load(f)
        weight_map = index.get("weight_map", {})
        for shard_file in sorted(set(weight_map.values())):
            shard_path = weight_dir / shard_file
            state_dict.update(load_file(str(shard_path), device=str(device)))
    else:
        # Load single safetensors file
        safetensors_files = sorted(weight_dir.glob("*.safetensors"))
        if not safetensors_files:
            raise FileNotFoundError(f"No safetensors files found in {weight_dir}")
        if len(safetensors_files) > 1:
            logger.warning(
                f"Found {len(safetensors_files)} safetensors files in {weight_dir} without an index; "
                f"loading only {safetensors_files[0].name}"
            )
        state_dict = load_file(str(safetensors_files[0]), device=str(device))
    if dtype is not None:
        # Cast in place: each original tensor can be freed as soon as it is
        # replaced, instead of holding two full copies of the checkpoint.
        for name, tensor in state_dict.items():
            if tensor.is_floating_point() and tensor.dtype != dtype:
                state_dict[name] = tensor.to(dtype)
    return state_dict
def _report_incompatible_keys(component: str, result, verbose: bool) -> None:
    """Log keys skipped by a ``strict=False`` ``load_state_dict`` call.

    Args:
        component: Name used as the log prefix (e.g. ``"DiT"``).
        result: Return value of ``nn.Module.load_state_dict`` (has
            ``missing_keys`` and ``unexpected_keys``).
        verbose: Whether to also report checkpoint keys the model did not use.
    """
    missing = list(result.missing_keys)
    unexpected = list(result.unexpected_keys)
    if missing:
        # Missing keys keep their constructor-time values (random init, or meta
        # tensors for the DiT), so always surface them.
        logger.warning(f"{component}: {len(missing)} key(s) missing from checkpoint, e.g. {missing[:5]}")
    if unexpected and verbose:
        logger.info(f"{component}: {len(unexpected)} unused checkpoint key(s), e.g. {unexpected[:5]}")


def load_from_local_dir(
    model_dir: Union[str, Path],
    device: str = "cuda",
    dtype: torch.dtype = torch.bfloat16,
    verbose: bool = False,
    compile: bool = False,
) -> dict:
    """
    Load all Z-Image components from local directory.

    Expects ``transformer/``, ``vae/``, ``text_encoder/``, ``scheduler/`` (and
    optionally ``tokenizer/``) subdirectories. Checkpoint keys that do not line up
    with the modules are reported: missing keys are logged as warnings, and unused
    keys are logged when ``verbose`` is set.

    Args:
        model_dir: Path to model directory
        device: Device to load models on
        dtype: Data type for DiT and text-encoder weights (the VAE is kept in float32)
        verbose: Whether to display loading logs
        compile: Whether to compile transformer and vae with torch.compile

    Returns:
        Dictionary containing transformer, vae, text_encoder, tokenizer, and scheduler
    """
    model_dir = Path(model_dir)
    # NOTE(review): prepends "<model_dir>/../../Z-Image/src" to sys.path on every
    # call so the import below resolves from a sibling checkout; confirm still needed.
    sys.path.insert(0, str(model_dir.parent.parent / "Z-Image" / "src"))
    from zimage.transformer import ZImageTransformer2DModel

    if verbose:
        logger.info(f"Loading Z-Image from: {model_dir}")

    # DiT
    if verbose:
        logger.info("Loading DiT...")
    transformer_dir = model_dir / "transformer"
    config = load_config(str(transformer_dir / "config.json"))
    # Build on the meta device (no allocation); real tensors are attached by
    # load_state_dict(assign=True) below.
    with torch.device("meta"):
        transformer = ZImageTransformer2DModel(
            all_patch_size=tuple(config.get("all_patch_size", DEFAULT_TRANSFORMER_PATCH_SIZE)),
            all_f_patch_size=tuple(config.get("all_f_patch_size", DEFAULT_TRANSFORMER_F_PATCH_SIZE)),
            in_channels=config.get("in_channels", DEFAULT_TRANSFORMER_IN_CHANNELS),
            dim=config.get("dim", DEFAULT_TRANSFORMER_DIM),
            n_layers=config.get("n_layers", DEFAULT_TRANSFORMER_N_LAYERS),
            n_refiner_layers=config.get("n_refiner_layers", DEFAULT_TRANSFORMER_N_REFINER_LAYERS),
            n_heads=config.get("n_heads", DEFAULT_TRANSFORMER_N_HEADS),
            n_kv_heads=config.get("n_kv_heads", DEFAULT_TRANSFORMER_N_KV_HEADS),
            norm_eps=config.get("norm_eps", DEFAULT_TRANSFORMER_NORM_EPS),
            qk_norm=config.get("qk_norm", DEFAULT_TRANSFORMER_QK_NORM),
            cap_feat_dim=config.get("cap_feat_dim", DEFAULT_TRANSFORMER_CAP_FEAT_DIM),
            rope_theta=config.get("rope_theta", ROPE_THETA),
            t_scale=config.get("t_scale", DEFAULT_TRANSFORMER_T_SCALE),
            axes_dims=config.get("axes_dims", ROPE_AXES_DIMS),
            axes_lens=config.get("axes_lens", ROPE_AXES_LENS),
        ).to(dtype)

    # DiT (weights to CPU then move to GPU to optimize memory)
    state_dict = load_sharded_safetensors(transformer_dir, device="cpu", dtype=dtype)
    dit_load_result = transformer.load_state_dict(state_dict, strict=False, assign=True)
    del state_dict
    _report_incompatible_keys("DiT", dit_load_result, verbose)

    if verbose:
        logger.info("Moving DiT to GPU...")
    transformer = transformer.to(device)
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    transformer.eval()

    # VAE
    if verbose:
        logger.info("Loading VAE...")
    vae_dir = model_dir / "vae"
    vae_config = load_config(str(vae_dir / "config.json"))
    vae = LocalAutoencoderKL(
        in_channels=vae_config.get("in_channels", DEFAULT_VAE_IN_CHANNELS),
        out_channels=vae_config.get("out_channels", DEFAULT_VAE_OUT_CHANNELS),
        down_block_types=tuple(vae_config.get("down_block_types", ("DownEncoderBlock2D",))),
        up_block_types=tuple(vae_config.get("up_block_types", ("UpDecoderBlock2D",))),
        block_out_channels=tuple(vae_config.get("block_out_channels", (64,))),
        layers_per_block=vae_config.get("layers_per_block", 1),
        latent_channels=vae_config.get("latent_channels", DEFAULT_VAE_LATENT_CHANNELS),
        norm_num_groups=vae_config.get("norm_num_groups", DEFAULT_VAE_NORM_NUM_GROUPS),
        scaling_factor=vae_config.get("scaling_factor", DEFAULT_VAE_SCALING_FACTOR),
        shift_factor=vae_config.get("shift_factor", None),
        use_quant_conv=vae_config.get("use_quant_conv", True),
        use_post_quant_conv=vae_config.get("use_post_quant_conv", True),
        mid_block_add_attention=vae_config.get("mid_block_add_attention", True),
    )

    # VAE (fp32 for better precision)
    vae_state_dict = load_sharded_safetensors(vae_dir, device="cpu")
    vae_load_result = vae.load_state_dict(vae_state_dict, strict=False)
    del vae_state_dict
    _report_incompatible_keys("VAE", vae_load_result, verbose)
    vae.to(device=device, dtype=torch.float32)
    vae.eval()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Text Encoder
    if verbose:
        logger.info("Loading Text Encoder...")
    text_encoder_dir = model_dir / "text_encoder"
    text_encoder = AutoModel.from_pretrained(
        str(text_encoder_dir),
        # torch_dtype=dtype, # some version use this
        dtype=dtype,
        trust_remote_code=True,
    )
    text_encoder.to(device)
    text_encoder.eval()

    # Tokenizer
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    if verbose:
        logger.info("Loading Tokenizer...")
    tokenizer_dir = model_dir / "tokenizer"
    tokenizer = AutoTokenizer.from_pretrained(
        str(tokenizer_dir) if tokenizer_dir.exists() else str(text_encoder_dir),
        trust_remote_code=True,
    )

    # Scheduler
    if verbose:
        logger.info("Loading Scheduler...")
    scheduler_dir = model_dir / "scheduler"
    scheduler_config = load_config(str(scheduler_dir / "scheduler_config.json"))
    scheduler = FlowMatchEulerDiscreteScheduler(
        num_train_timesteps=scheduler_config.get("num_train_timesteps", DEFAULT_SCHEDULER_NUM_TRAIN_TIMESTEPS),
        shift=scheduler_config.get("shift", DEFAULT_SCHEDULER_SHIFT),
        use_dynamic_shifting=scheduler_config.get("use_dynamic_shifting", DEFAULT_SCHEDULER_USE_DYNAMIC_SHIFTING),
    )

    if compile:
        if verbose:
            logger.info("Compiling DiT and VAE...")
        transformer = torch.compile(transformer)
        vae = torch.compile(vae)

    if verbose:
        logger.success("All components loaded successfully")

    return {
        "transformer": transformer,
        "vae": vae,
        "text_encoder": text_encoder,
        "tokenizer": tokenizer,
        "scheduler": scheduler,
    }
================================================
FILE: src/zimage/__init__.py
================================================
"""Z-Image PyTorch Native Implementation."""
from .pipeline import generate
from .transformer import ZImageTransformer2DModel
# Public names exported by ``from zimage import *``.
__all__ = [
    "ZImageTransformer2DModel",
    "generate",
]
================================================
FILE: src/zimage/autoencoder.py
================================================
"""AutoencoderKL implementation compatible with diffusers weights."""
# Modified from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/autoencoder.py
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
import torch.nn as nn
@dataclass
class AutoencoderKLOutput:
    """Container returned by ``AutoencoderKL.decode`` when ``return_dict`` is True."""

    # Tensor produced by the decoder.
    sample: torch.Tensor
class AutoencoderConfig:
    """Minimal attribute-style config: keyword arguments become attributes.

    Reading an unknown regular attribute returns ``None`` instead of raising
    AttributeError; callers may rely on this. Dunder names (``__setstate__``,
    ``__deepcopy__``, ...) are excluded and raise AttributeError as usual.
    Otherwise ``copy.deepcopy`` and ``pickle`` would find a "``__setstate__``" of
    ``None`` and fail with ``TypeError: 'NoneType' object is not callable``.
    """

    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    def get(self, key, default=None):
        """Dict-style lookup with a default."""
        return self.__dict__.get(key, default)

    def __getattr__(self, name):
        # Only called when normal lookup fails.
        if name.startswith("__") and name.endswith("__"):
            raise AttributeError(name)
        return self.__dict__.get(name)
def swish(x):
    """Swish / SiLU activation, computed as ``x * sigmoid(x)``."""
    gate = torch.sigmoid(x)
    return x * gate
class ResnetBlock2D(nn.Module):
    """Residual block: (GroupNorm -> swish -> 3x3 conv) twice, plus a skip path.

    The skip path is a 1x1 conv when the channel count changes, else identity.
    ``temb_channels`` / ``temb`` are accepted for signature compatibility and ignored.
    Submodule names (``norm1``, ``conv1``, ...) double as state_dict keys, and
    their registration order is kept fixed.
    """

    def __init__(self, in_channels, out_channels=None, dropout=0.0, temb_channels=512, groups=32, eps=1e-6):
        super().__init__()
        if not out_channels:
            out_channels = in_channels
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.nonlinearity = swish
        needs_projection = in_channels != out_channels
        self.conv_shortcut = (
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) if needs_projection else None
        )

    def forward(self, input_tensor, temb=None):
        act = self.nonlinearity
        h = self.conv1(act(self.norm1(input_tensor)))
        h = self.conv2(self.dropout(act(self.norm2(h))))
        skip = input_tensor if self.conv_shortcut is None else self.conv_shortcut(input_tensor)
        return (skip + h) / 1.0
class Attention(nn.Module):
    """Single-head self-attention over all H*W positions, with a residual add.

    ``heads`` is stored but ``forward`` always attends with one head over the full
    channel width; ``dim_head`` is accepted and ignored. ``to_out`` is a ModuleList
    so its parameters are named ``to_out.0.*``.
    """

    def __init__(self, in_channels, heads=1, dim_head=None, groups=32, eps=1e-6):
        super().__init__()
        self.heads = heads
        self.in_channels = in_channels
        self.group_norm = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
        self.to_q = nn.Linear(in_channels, in_channels)
        self.to_k = nn.Linear(in_channels, in_channels)
        self.to_v = nn.Linear(in_channels, in_channels)
        self.to_out = nn.ModuleList([nn.Linear(in_channels, in_channels)])

    def forward(self, hidden_states):
        batch, channels, height, width = hidden_states.shape
        # (B, C, H, W) -> (B, H*W, C): one token per spatial position
        tokens = self.group_norm(hidden_states).view(batch, channels, -1).transpose(1, 2)
        attended = torch.nn.functional.scaled_dot_product_attention(
            self.to_q(tokens), self.to_k(tokens), self.to_v(tokens)
        )
        projected = self.to_out[0](attended)
        return hidden_states + projected.transpose(1, 2).view(batch, channels, height, width)
class Downsample2D(nn.Module):
    """Spatial downsampling: stride-2 3x3 conv, or 2x2 average pooling when ``with_conv`` is False.

    ``out_channels`` (default ``channels``) only applies to the conv path.
    """

    def __init__(self, channels, with_conv=True, out_channels=None, padding=1):
        super().__init__()
        self.with_conv = with_conv
        if with_conv:
            self.conv = nn.Conv2d(channels, out_channels or channels, kernel_size=3, stride=2, padding=padding)

    def forward(self, hidden_states):
        if not self.with_conv:
            return torch.nn.functional.avg_pool2d(hidden_states, kernel_size=2, stride=2)
        return self.conv(hidden_states)
class Upsample2D(nn.Module):
    """2x nearest-neighbour upsampling, optionally followed by a 3x3 conv.

    ``out_channels`` (default ``channels``) only applies when ``with_conv`` is True.
    """

    def __init__(self, channels, with_conv=True, out_channels=None):
        super().__init__()
        self.with_conv = with_conv
        if with_conv:
            self.conv = nn.Conv2d(channels, out_channels or channels, kernel_size=3, stride=1, padding=1)

    def forward(self, hidden_states):
        upsampled = torch.nn.functional.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
        return self.conv(upsampled) if self.with_conv else upsampled
class DownEncoderBlock2D(nn.Module):
    """``num_layers`` ResnetBlock2D layers, optionally followed by a stride-2 conv downsampler.

    Before each downsampler the input is zero-padded by one pixel on the right
    and bottom only; the downsampler conv itself is built with ``padding=0``.
    """

    def __init__(self, in_channels, out_channels, num_layers=1, resnet_eps=1e-6, resnet_groups=32, add_downsample=True):
        super().__init__()
        self.resnets = nn.ModuleList(
            ResnetBlock2D(
                in_channels if layer_idx == 0 else out_channels,
                out_channels,
                eps=resnet_eps,
                groups=resnet_groups,
            )
            for layer_idx in range(num_layers)
        )
        self.downsamplers = (
            nn.ModuleList([Downsample2D(out_channels, with_conv=True, out_channels=out_channels, padding=0)])
            if add_downsample
            else None
        )

    def forward(self, hidden_states):
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states)
        for downsampler in self.downsamplers or ():
            padded = torch.nn.functional.pad(hidden_states, (0, 1, 0, 1), mode="constant", value=0)
            hidden_states = downsampler(padded)
        return hidden_states
class UpDecoderBlock2D(nn.Module):
    """``num_layers`` ResnetBlock2D layers, optionally followed by an Upsample2D (2x nearest + conv)."""

    def __init__(self, in_channels, out_channels, num_layers=1, resnet_eps=1e-6, resnet_groups=32, add_upsample=True):
        super().__init__()
        self.resnets = nn.ModuleList(
            ResnetBlock2D(
                in_channels if layer_idx == 0 else out_channels,
                out_channels,
                eps=resnet_eps,
                groups=resnet_groups,
            )
            for layer_idx in range(num_layers)
        )
        self.upsamplers = (
            nn.ModuleList([Upsample2D(out_channels, with_conv=True, out_channels=out_channels)])
            if add_upsample
            else None
        )

    def forward(self, hidden_states):
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states)
        for upsampler in self.upsamplers or ():
            hidden_states = upsampler(hidden_states)
        return hidden_states
class UNetMidBlock2D(nn.Module):
    """Bottleneck: ResnetBlock2D -> single-head Attention -> ResnetBlock2D, all at ``in_channels``.

    ``attention_head_dim`` is accepted but unused.
    """

    def __init__(self, in_channels, resnet_eps=1e-6, resnet_groups=32, attention_head_dim=None):
        super().__init__()

        def make_resnet():
            return ResnetBlock2D(in_channels, in_channels, eps=resnet_eps, groups=resnet_groups)

        self.resnets = nn.ModuleList([make_resnet(), make_resnet()])
        self.attentions = nn.ModuleList([Attention(in_channels, heads=1, groups=resnet_groups, eps=resnet_eps)])

    def forward(self, hidden_states):
        first_resnet, last_resnet = self.resnets[0], self.resnets[1]
        hidden_states = first_resnet(hidden_states)
        for attention in self.attentions:
            hidden_states = attention(hidden_states)
        return last_resnet(hidden_states)
class Encoder(nn.Module):
    """Convolutional VAE encoder.

    Pipeline: ``conv_in`` -> one ``DownEncoderBlock2D`` per entry of
    ``block_out_channels`` (every block but the last downsamples) ->
    ``UNetMidBlock2D`` -> GroupNorm -> SiLU -> ``conv_out``.

    Args:
        in_channels: Channels of the input image.
        out_channels: Latent channels. ``conv_out`` emits ``2 * out_channels``
            when ``double_z`` is True (presumably mean and log-variance, as in
            diffusers' AutoencoderKL — TODO confirm; no encode() caller is
            visible in this module).
        block_out_channels: Output width of each down block, in order.
        layers_per_block: Resnet layers per down block.
        norm_num_groups: GroupNorm group count used throughout.
        double_z: Whether ``conv_out`` doubles the output channels.
    """

    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        block_out_channels=(64,),
        layers_per_block=2,
        norm_num_groups=32,
        double_z=True,
    ):
        super().__init__()
        self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
        # Attribute names and ModuleList indices form the state_dict keys, so
        # they must stay aligned with the checkpoint layout.
        self.down_blocks = nn.ModuleList([])
        output_channel = block_out_channels[0]
        for i, block_out_channel in enumerate(block_out_channels):
            input_channel = output_channel
            output_channel = block_out_channel
            is_final_block = i == len(block_out_channels) - 1
            block = DownEncoderBlock2D(
                input_channel,
                output_channel,
                num_layers=layers_per_block,
                resnet_groups=norm_num_groups,
                add_downsample=not is_final_block,
            )
            self.down_blocks.append(block)
        self.mid_block = UNetMidBlock2D(
            block_out_channels[-1],
            resnet_groups=norm_num_groups,
        )
        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
        self.conv_act = nn.SiLU()
        conv_out_channels = 2 * out_channels if double_z else out_channels
        self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)

    def forward(self, x):
        """Map an image batch to the pre-quantization latent tensor."""
        x = self.conv_in(x)
        for block in self.down_blocks:
            x = block(x)
        x = self.mid_block(x)
        x = self.conv_norm_out(x)
        x = self.conv_act(x)
        x = self.conv_out(x)
        return x
class Decoder(nn.Module):
    """Convolutional VAE decoder.

    Pipeline: ``conv_in`` -> ``UNetMidBlock2D`` -> one ``UpDecoderBlock2D`` per
    entry of ``block_out_channels`` in reverse order (every block but the last
    upsamples 2x) -> GroupNorm -> SiLU -> ``conv_out``.

    Args:
        in_channels: Latent channels fed to ``conv_in``.
        out_channels: Channels of the reconstructed image.
        block_out_channels: Encoder-order block widths; traversed reversed here.
        layers_per_block: Each up block gets ``layers_per_block + 1`` resnets
            (one more than the matching encoder block).
        norm_num_groups: GroupNorm group count used throughout.
    """

    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        block_out_channels=(64,),
        layers_per_block=2,
        norm_num_groups=32,
    ):
        super().__init__()
        self.conv_in = nn.Conv2d(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1)
        self.mid_block = UNetMidBlock2D(
            block_out_channels[-1],
            resnet_groups=norm_num_groups,
        )
        # Attribute names and ModuleList indices form the state_dict keys, so
        # they must stay aligned with the checkpoint layout.
        self.up_blocks = nn.ModuleList([])
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        for i, block_out_channel in enumerate(reversed_block_out_channels):
            input_channel = output_channel
            output_channel = block_out_channel
            is_final_block = i == len(block_out_channels) - 1
            block = UpDecoderBlock2D(
                input_channel,
                output_channel,
                num_layers=layers_per_block + 1,
                resnet_groups=norm_num_groups,
                add_upsample=not is_final_block,
            )
            self.up_blocks.append(block)
        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
        self.conv_act = nn.SiLU()
        self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        """Map a latent tensor to an image tensor."""
        x = self.conv_in(x)
        x = self.mid_block(x)
        for block in self.up_blocks:
            x = block(x)
        x = self.conv_norm_out(x)
        x = self.conv_act(x)
        x = self.conv_out(x)
        return x
class AutoencoderKL(nn.Module):
    """KL autoencoder (encoder + decoder) laid out to accept diffusers-style weights.

    Only ``decode`` is implemented here; there is no ``encode`` method.

    NOTE(review): the constructor accepts but does not use ``down_block_types``,
    ``up_block_types``, ``act_fn``, ``sample_size``, ``force_upcast``,
    ``mid_block_add_attention`` and ``**kwargs``. In particular the mid blocks
    always contain attention, even if ``mid_block_add_attention`` is False.

    Args:
        in_channels / out_channels: Image channels for encoder input / decoder output.
        block_out_channels, layers_per_block, norm_num_groups: Forwarded to
            ``Encoder`` and ``Decoder``.
        latent_channels: Latent width; the encoder emits ``2 * latent_channels``.
        scaling_factor, shift_factor: Only recorded in ``self.config``; ``decode``
            does not apply them, so callers must un-scale latents themselves.
        use_quant_conv / use_post_quant_conv: Whether to create the 1x1
            ``quant_conv`` / ``post_quant_conv`` layers (``None`` otherwise).
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
        up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
        block_out_channels: Tuple[int] = (64,),
        layers_per_block: int = 1,
        act_fn: str = "silu",
        latent_channels: int = 4,
        norm_num_groups: int = 32,
        sample_size: int = 32,
        scaling_factor: float = 0.18215,
        shift_factor: Optional[float] = None,
        force_upcast: bool = True,
        use_quant_conv: bool = True,
        use_post_quant_conv: bool = True,
        mid_block_add_attention: bool = True,
        **kwargs,
    ):
        super().__init__()
        # Subset of the constructor args exposed diffusers-style as ``vae.config``.
        self.config = AutoencoderConfig(
            in_channels=in_channels,
            out_channels=out_channels,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            latent_channels=latent_channels,
            scaling_factor=scaling_factor,
            shift_factor=shift_factor,
        )
        self.encoder = Encoder(
            in_channels=in_channels,
            out_channels=latent_channels,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            norm_num_groups=norm_num_groups,
            double_z=True,
        )
        self.decoder = Decoder(
            in_channels=latent_channels,
            out_channels=out_channels,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            norm_num_groups=norm_num_groups,
        )
        self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None
        self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) if use_post_quant_conv else None

    @property
    def dtype(self):
        """Dtype of the first registered parameter (assumed uniform across the model)."""
        return next(self.parameters()).dtype

    # NOTE: despite the annotation, returns a 1-tuple ``(dec,)`` when return_dict is False.
    def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
        """Decode latents ``z`` into images.

        Applies ``post_quant_conv`` (if present) and then the decoder. No
        scaling/shift is undone here.

        Returns:
            ``AutoencoderKLOutput(sample=...)``, or ``(sample,)`` when
            ``return_dict`` is False.
        """
        if self.post_quant_conv is not None:
            z = self.post_quant_conv(z)
        dec = self.decoder(z)
        if not return_dict:
            return (dec,)
        return AutoencoderKLOutput(sample=dec)
================================================
FILE: src/zimage/pipeline.py
================================================
"""Z-Image Pipeline."""
import inspect
from typing import List, Optional, Union
from loguru import logger
import torch
from config import (
BASE_IMAGE_SEQ_LEN,
BASE_SHIFT,
DEFAULT_CFG_TRUNCATION,
DEFAULT_GUIDANCE_SCALE,
DEFAULT_HEIGHT,
DEFAULT_INFERENCE_STEPS,
DEFAULT_MAX_SEQUENCE_LENGTH,
DEFAULT_WIDTH,
MAX_IMAGE_SEQ_LEN,
MAX_SHIFT,
)
def calculate_shift(
    image_seq_len,
    base_seq_len: int = BASE_IMAGE_SEQ_LEN,
    max_seq_len: int = MAX_IMAGE_SEQ_LEN,
    base_shift: float = BASE_SHIFT,
    max_shift: float = MAX_SHIFT,
):
    """Linearly map an image token count to the shift value ``mu``.

    The line passes through ``(base_seq_len, base_shift)`` and
    ``(max_seq_len, max_shift)``. It is extrapolated outside that interval
    (no clamping).
    """
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return image_seq_len * slope + intercept
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """Configure ``scheduler`` and return its timestep schedule.

    Exactly one of three modes is used: an explicit ``timesteps`` list, an
    explicit ``sigmas`` list, or ``num_inference_steps``. Explicit schedules are
    only accepted if ``scheduler.set_timesteps`` declares the matching parameter.

    Returns:
        ``(scheduler.timesteps, step_count)``. For explicit schedules,
        ``step_count`` is the schedule length; otherwise it is
        ``num_inference_steps`` as passed.

    Raises:
        ValueError: Both ``timesteps`` and ``sigmas`` given, or the scheduler
            does not accept the requested custom schedule.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed.")

    def scheduler_accepts(arg_name):
        return arg_name in inspect.signature(scheduler.set_timesteps).parameters

    if timesteps is None and sigmas is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    if timesteps is not None:
        if not scheduler_accepts("timesteps"):
            raise ValueError("The scheduler does not support custom timestep schedules.")
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        if not scheduler_accepts("sigmas"):
            raise ValueError("The scheduler does not support custom sigmas schedules.")
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    schedule = scheduler.timesteps
    return schedule, len(schedule)
def _encode_prompts(tokenizer, text_encoder, prompts, max_sequence_length, device):
    """Encode chat-formatted prompts and return one unpadded embedding per prompt.

    Each prompt is wrapped as a single user message with the tokenizer's chat
    template, tokenized with ``max_length`` padding/truncation, and passed
    through ``text_encoder``. The second-to-last hidden state is used, and
    positions whose attention mask is False are dropped, so the returned
    tensors may have different lengths.
    """
    formatted_prompts = [
        tokenizer.apply_chat_template(
            [{"role": "user", "content": p}],
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=True,
        )
        for p in prompts
    ]
    text_inputs = tokenizer(
        formatted_prompts,
        padding="max_length",
        max_length=max_sequence_length,
        truncation=True,
        return_tensors="pt",
    )
    input_ids = text_inputs.input_ids.to(device)
    masks = text_inputs.attention_mask.to(device).bool()
    hidden_states = text_encoder(
        input_ids=input_ids,
        attention_mask=masks,
        output_hidden_states=True,
    ).hidden_states[-2]
    return [embeds[mask] for embeds, mask in zip(hidden_states, masks)]


@torch.no_grad()
def generate(
    transformer,
    vae,
    text_encoder,
    tokenizer,
    scheduler,
    prompt: Union[str, List[str]],
    height: int = DEFAULT_HEIGHT,
    width: int = DEFAULT_WIDTH,
    num_inference_steps: int = DEFAULT_INFERENCE_STEPS,
    guidance_scale: float = DEFAULT_GUIDANCE_SCALE,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: int = 1,
    generator: Optional[torch.Generator] = None,
    cfg_normalization: bool = False,
    cfg_truncation: float = DEFAULT_CFG_TRUNCATION,
    max_sequence_length: int = DEFAULT_MAX_SEQUENCE_LENGTH,
    output_type: str = "pil",
):
    """Run text-to-image sampling and decode the result.

    Args:
        transformer: Denoiser called as ``transformer(latent_list, timesteps,
            caption_embeds_list)``; element ``[0]`` of its result is a list of
            per-sample outputs. Must expose ``in_channels`` and ``parameters()``.
        vae: Decoder providing ``config.scaling_factor``, optionally
            ``config.shift_factor`` / ``config.block_out_channels``, ``dtype``
            and ``decode``.
        text_encoder, tokenizer: Used to embed prompts (see ``_encode_prompts``).
        scheduler: Scheduler with ``config.get``, ``set_timesteps`` and ``step``.
        prompt: One prompt or a list of prompts.
        height, width: Output size in pixels; each must be divisible by twice
            the VAE downsampling factor.
        num_inference_steps: Number of denoising steps.
        guidance_scale: Classifier-free guidance is enabled when this is > 1.
        negative_prompt: A single string (used for every prompt) or a list
            with one entry per prompt. Defaults to empty strings.
        num_images_per_prompt: Samples generated for each prompt.
        generator: RNG for the initial latent noise.
        cfg_normalization: When truthy, caps the guided prediction's norm at
            ``float(cfg_normalization)`` times the positive prediction's norm.
        cfg_truncation: If <= 1, guidance is turned off once the normalized
            time input exceeds this value.
        max_sequence_length: Tokenizer max length for prompts.
        output_type: ``"latent"`` returns latents, ``"pil"`` returns a list of
            PIL images, anything else returns the decoded tensor.

    Raises:
        ValueError: if ``height``/``width`` are not suitably divisible, or a
            list ``negative_prompt`` does not match the number of prompts.
    """
    device = next(transformer.parameters()).device
    # dtype used to cast latents fed to the model (same choice in both CFG paths).
    model_dtype = transformer.dtype if hasattr(transformer, "dtype") else next(transformer.parameters()).dtype

    if hasattr(vae, "config") and hasattr(vae.config, "block_out_channels"):
        vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
    else:
        vae_scale_factor = 8
    # The transformer is called with its default 2x2 patch size, so pixel sizes
    # must divide by the VAE factor times 2.
    vae_scale = vae_scale_factor * 2
    if height % vae_scale != 0:
        raise ValueError(f"Height must be divisible by {vae_scale} (got {height}).")
    if width % vae_scale != 0:
        raise ValueError(f"Width must be divisible by {vae_scale} (got {width}).")

    if isinstance(prompt, str):
        prompt = [prompt]
    batch_size = len(prompt)
    do_classifier_free_guidance = guidance_scale > 1.0
    logger.info(f"Generating image: {height}x{width}, steps={num_inference_steps}, cfg={guidance_scale}")

    prompt_embeds_list = _encode_prompts(tokenizer, text_encoder, prompt, max_sequence_length, device)

    negative_prompt_embeds_list = []
    if do_classifier_free_guidance:
        if negative_prompt is None:
            negative_prompt = [""] * batch_size
        elif isinstance(negative_prompt, str):
            # Broadcast to the batch: the CFG batch pairs sample i with negative i,
            # so a single-entry list would misalign the two halves.
            negative_prompt = [negative_prompt] * batch_size
        elif len(negative_prompt) != batch_size:
            raise ValueError(
                f"`negative_prompt` has {len(negative_prompt)} entries but `prompt` has {batch_size}."
            )
        negative_prompt_embeds_list = _encode_prompts(
            tokenizer, text_encoder, negative_prompt, max_sequence_length, device
        )

    if num_images_per_prompt > 1:
        # Repeat each embedding consecutively so samples of the same prompt are adjacent.
        prompt_embeds_list = [pe for pe in prompt_embeds_list for _ in range(num_images_per_prompt)]
        negative_prompt_embeds_list = [
            npe for npe in negative_prompt_embeds_list for _ in range(num_images_per_prompt)
        ]

    actual_batch_size = batch_size * num_images_per_prompt
    height_latent = 2 * (int(height) // vae_scale)
    width_latent = 2 * (int(width) // vae_scale)
    shape = (actual_batch_size, transformer.in_channels, height_latent, width_latent)
    latents = torch.randn(shape, generator=generator, device=device, dtype=torch.float32)

    # Number of 2x2 latent patches; sets the resolution-dependent shift ``mu``.
    image_seq_len = (latents.shape[2] // 2) * (latents.shape[3] // 2)
    mu = calculate_shift(
        image_seq_len,
        scheduler.config.get("base_image_seq_len", 256),
        scheduler.config.get("max_image_seq_len", 4096),
        scheduler.config.get("base_shift", 0.5),
        scheduler.config.get("max_shift", 1.15),
    )
    # The repo's scheduler uses sigma_min as the endpoint of its default schedule.
    scheduler.sigma_min = 0.0
    timesteps, num_inference_steps = retrieve_timesteps(
        scheduler,
        num_inference_steps,
        device,
        sigmas=None,
        mu=mu,
    )
    logger.info(f"Sampling loop start: {num_inference_steps} steps")
    from tqdm import tqdm

    for i, t in enumerate(tqdm(timesteps, desc="Denoising", total=len(timesteps))):
        # If current t is 0 and it's the last step, skip computation
        if t == 0 and i == len(timesteps) - 1:
            logger.debug(f"Step {i+1}/{num_inference_steps} | t: {t.item():.2f} | Skipping last step")
            continue
        # Scheduler timesteps run 1000 -> 0; the model's time input runs 0 -> 1.
        timestep = t.expand(latents.shape[0])
        timestep = (1000 - timestep) / 1000
        t_norm = timestep[0].item()

        current_guidance_scale = guidance_scale
        if do_classifier_free_guidance and cfg_truncation is not None and float(cfg_truncation) <= 1:
            if t_norm > cfg_truncation:
                current_guidance_scale = 0.0
        apply_cfg = do_classifier_free_guidance and current_guidance_scale > 0

        latents_typed = latents.to(model_dtype)
        if apply_cfg:
            # Positive samples first, negatives second, matching the embedding order.
            latent_model_input = latents_typed.repeat(2, 1, 1, 1)
            prompt_embeds_model_input = prompt_embeds_list + negative_prompt_embeds_list
            timestep_model_input = timestep.repeat(2)
        else:
            latent_model_input = latents_typed
            prompt_embeds_model_input = prompt_embeds_list
            timestep_model_input = timestep

        # The transformer consumes a list of (C, 1, H, W) tensors.
        latent_model_input_list = list(latent_model_input.unsqueeze(2).unbind(dim=0))
        model_out_list = transformer(
            latent_model_input_list,
            timestep_model_input,
            prompt_embeds_model_input,
        )[0]

        if apply_cfg:
            pos_out = model_out_list[:actual_batch_size]
            neg_out = model_out_list[actual_batch_size:]
            noise_pred = []
            for pos_sample, neg_sample in zip(pos_out, neg_out):
                pos = pos_sample.float()
                neg = neg_sample.float()
                pred = pos + current_guidance_scale * (pos - neg)
                if cfg_normalization and float(cfg_normalization) > 0.0:
                    # Rescale so the guided norm never exceeds cfg_normalization x the positive norm.
                    ori_pos_norm = torch.linalg.vector_norm(pos)
                    new_pos_norm = torch.linalg.vector_norm(pred)
                    max_new_norm = ori_pos_norm * float(cfg_normalization)
                    if new_pos_norm > max_new_norm:
                        pred = pred * (max_new_norm / new_pos_norm)
                noise_pred.append(pred)
            noise_pred = torch.stack(noise_pred, dim=0)
        else:
            noise_pred = torch.stack([out.float() for out in model_out_list], dim=0)

        # The model output is negated before the Euler step (sign convention of this model/scheduler pair).
        noise_pred = -noise_pred.squeeze(2)
        latents = scheduler.step(noise_pred.to(torch.float32), t, latents, return_dict=False)[0]
        assert latents.dtype == torch.float32

    if output_type == "latent":
        return latents

    shift_factor = getattr(vae.config, "shift_factor", 0.0) or 0.0
    latents = (latents.to(vae.dtype) / vae.config.scaling_factor) + shift_factor
    image = vae.decode(latents, return_dict=False)[0]
    if output_type == "pil":
        from PIL import Image

        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        image = (image * 255).round().astype("uint8")
        image = [Image.fromarray(img) for img in image]
    return image
================================================
FILE: src/zimage/scheduler.py
================================================
"""FlowMatchEulerDiscreteScheduler implementation."""
# Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
from dataclasses import dataclass
import math
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
@dataclass
class SchedulerOutput:
    """Result container returned by the scheduler's ``step`` when ``return_dict`` is True."""

    # Sample after one scheduler update.
    prev_sample: torch.FloatTensor
class SchedulerConfig:
    """Permissive attribute/dict-style view over scheduler settings.

    Keyword arguments become instance attributes. Unknown regular names read
    as ``None`` through both ``get`` (unless a default is given) and attribute
    access; unknown dunder names raise ``AttributeError`` like normal objects.
    """

    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    def get(self, key, default=None):
        """Return the value stored for ``key``, or ``default`` if it is unset."""
        return self.__dict__.get(key, default)

    def __getattr__(self, name):
        # Only reached when normal lookup fails. Dunder names must raise so that
        # protocol probes (e.g. copy/pickle checking for ``__setstate__``) see
        # "not defined" instead of a non-callable ``None``; otherwise
        # ``copy.deepcopy`` and unpickling fail with TypeError.
        if name.startswith("__") and name.endswith("__"):
            raise AttributeError(name)
        return self.__dict__.get(name)
class FlowMatchEulerDiscreteScheduler:
    """Euler scheduler for flow matching.

    Timesteps are ``sigma * num_train_timesteps``. ``set_timesteps`` builds an
    inference schedule and appends a terminal sigma of 0; ``step`` advances the
    sample with one explicit Euler update between consecutive sigmas.
    """

    def __init__(
        self,
        num_train_timesteps: int = 1000,
        shift: float = 1.0,
        use_dynamic_shifting: bool = False,
        **kwargs,
    ):
        """Build the training-time schedule.

        Args:
            num_train_timesteps: Timesteps are sigmas scaled by this value.
            shift: Static sigma shift used when dynamic shifting is off.
            use_dynamic_shifting: If True, ``set_timesteps`` requires ``mu`` and
                applies ``time_shift`` instead of the static shift.
            **kwargs: Accepted for config compatibility and ignored.
        """
        self.num_train_timesteps = num_train_timesteps
        self.shift = shift
        self.use_dynamic_shifting = use_dynamic_shifting
        self.config = SchedulerConfig(
            num_train_timesteps=num_train_timesteps,
            shift=shift,
            use_dynamic_shifting=use_dynamic_shifting,
        )
        timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
        timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
        sigmas = timesteps / num_train_timesteps
        if not use_dynamic_shifting:
            sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
        self.timesteps = sigmas * num_train_timesteps
        self.sigmas = sigmas.to("cpu")
        # Endpoints of the default range used by ``set_timesteps``.
        self.sigma_min = self.sigmas[-1].item()
        self.sigma_max = self.sigmas[0].item()
        self._step_index = None
        self._begin_index = None

    def set_timesteps(
        self,
        num_inference_steps: Optional[int] = None,
        device: Union[str, torch.device] = None,
        sigmas: Optional[List[float]] = None,
        mu: Optional[float] = None,
        timesteps: Optional[List[float]] = None,
    ):
        """Build the inference schedule.

        Args:
            num_inference_steps: Number of steps; inferred from ``sigmas`` or
                ``timesteps`` when omitted.
            device: Device for the resulting ``timesteps``/``sigmas`` tensors.
            sigmas: Optional custom sigma schedule (sequence or array).
            mu: Shift parameter; required when ``use_dynamic_shifting`` is True.
            timesteps: Optional custom timestep schedule (sequence or array).
                When given, ``self.timesteps`` is exactly this schedule.

        Raises:
            ValueError: if dynamic shifting is enabled and ``mu`` is None.
        """
        if self.use_dynamic_shifting and mu is None:
            raise ValueError("`mu` must be passed when `use_dynamic_shifting` is set to `True`.")
        if timesteps is not None:
            # Accept plain sequences as annotated: the division and
            # ``torch.from_numpy`` below both require an ndarray.
            timesteps = np.array(timesteps).astype(np.float32)
        passed_timesteps = timesteps
        if num_inference_steps is None:
            num_inference_steps = len(sigmas) if sigmas is not None else len(timesteps)
        self.num_inference_steps = num_inference_steps
        if sigmas is None:
            if timesteps is None:
                # ``num_inference_steps + 1`` points with the last dropped: the
                # schedule stops one step before ``sigma_min``; a terminal sigma
                # of 0 is appended below.
                timesteps = np.linspace(
                    self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps + 1
                )[:-1]
            sigmas = timesteps / self.num_train_timesteps
        else:
            sigmas = np.array(sigmas).astype(np.float32)
        if self.use_dynamic_shifting:
            sigmas = self.time_shift(mu, 1.0, sigmas)
        else:
            sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas)
        sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
        if passed_timesteps is None:
            timesteps = sigmas * self.num_train_timesteps
        else:
            timesteps = torch.from_numpy(passed_timesteps).to(dtype=torch.float32, device=device)
        sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
        self.timesteps = timesteps
        self.sigmas = sigmas
        self._step_index = None
        self._begin_index = None

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        """Return the index of ``timestep`` in the schedule (the second match if it occurs more than once)."""
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps
        indices = (schedule_timesteps == timestep).nonzero()
        pos = 1 if len(indices) > 1 else 0
        return indices[pos].item()

    def _init_step_index(self, timestep):
        """Initialise ``_step_index`` from ``_begin_index`` or by locating ``timestep``."""
        if self._begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        return_dict: bool = True,
        **kwargs,
    ) -> Union["SchedulerOutput", Tuple]:
        """Predict the sample at the previous timestep.

        Computes ``sample + (sigma_next - sigma) * model_output`` in float32,
        advances the internal step index, and casts the result to
        ``model_output``'s dtype.
        """
        if self._step_index is None:
            self._init_step_index(timestep)
        sample = sample.to(torch.float32)
        sigma_idx = self._step_index
        sigma = self.sigmas[sigma_idx]
        sigma_next = self.sigmas[sigma_idx + 1]
        dt = sigma_next - sigma
        prev_sample = sample + dt * model_output
        self._step_index += 1
        prev_sample = prev_sample.to(model_output.dtype)
        if not return_dict:
            return (prev_sample,)
        return SchedulerOutput(prev_sample=prev_sample)

    def _sigma_to_t(self, sigma):
        """Convert a sigma to a timestep by scaling with ``num_train_timesteps``."""
        return sigma * self.num_train_timesteps

    def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
        """Apply ``exp(mu) / (exp(mu) + (1 / t - 1) ** sigma)`` elementwise to ``t``."""
        return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
================================================
FILE: src/zimage/transformer.py
================================================
"""Z-Image Transformer."""
import math
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from config import (
ADALN_EMBED_DIM,
FREQUENCY_EMBEDDING_SIZE,
MAX_PERIOD,
ROPE_AXES_DIMS,
ROPE_AXES_LENS,
ROPE_THETA,
SEQ_MULTI_OF,
)
class TimestepEmbedder(nn.Module):
    """Embed scalar timesteps via sinusoidal features followed by a two-layer SiLU MLP."""

    def __init__(self, out_size, mid_size=None, frequency_embedding_size=FREQUENCY_EMBEDDING_SIZE):
        super().__init__()
        hidden_size = out_size if mid_size is None else mid_size
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, out_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=MAX_PERIOD):
        """Return ``[cos(t * f), sin(t * f)]`` features of width ``dim`` for a 1-D ``t``.

        ``f_k = exp(-log(max_period) * k / (dim // 2))``; an extra zero column
        is appended when ``dim`` is odd. Computed in float32 with CUDA
        autocast disabled.
        """
        with torch.amp.autocast("cuda", enabled=False):
            half = dim // 2
            log_scale = -math.log(max_period)
            freqs = torch.exp(log_scale * torch.arange(0, half, dtype=torch.float32, device=t.device) / half)
            angles = t[:, None].float() * freqs[None]
            features = torch.cat([angles.cos(), angles.sin()], dim=-1)
            if dim % 2:
                features = torch.cat([features, torch.zeros_like(features[:, :1])], dim=-1)
            return features

    def forward(self, t):
        features = self.timestep_embedding(t, self.frequency_embedding_size)
        first_weight_dtype = self.mlp[0].weight.dtype
        # Match the MLP's dtype unless its weights are not floating point.
        if first_weight_dtype.is_floating_point:
            features = features.to(first_weight_dtype)
        return self.mlp(features)
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension with a learned per-channel scale."""

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        normalized = x * inv_rms
        return normalized * self.weight
class FeedForward(nn.Module):
    """Gated SiLU feed-forward: ``w2(silu(w1(x)) * w3(x))`` with bias-free projections."""

    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        gate = F.silu(self.w1(x))
        up = self.w3(x)
        return self.w2(gate * up)
def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    """Rotate adjacent channel pairs of ``x_in`` by complex factors ``freqs_cis``.

    The last dimension of ``x_in`` is read as (real, imag) pairs and multiplied
    by ``freqs_cis``, which gets a singleton axis inserted at dim 2 for
    broadcasting. Runs in float32 with CUDA autocast disabled; the result is
    cast back to ``x_in``'s dtype.
    """
    with torch.amp.autocast("cuda", enabled=False):
        pairs = x_in.float().reshape(*x_in.shape[:-1], -1, 2)
        rotated = torch.view_as_complex(pairs) * freqs_cis.unsqueeze(2)
        return torch.view_as_real(rotated).flatten(3).type_as(x_in)
class ZImageAttention(nn.Module):
    """Multi-head attention with optional per-head RMSNorm on queries/keys and rotary embeddings.

    The actual attention kernel is delegated to ``utils.attention.dispatch_attention``.
    """

    # Passed as ``backend`` to dispatch_attention; None lets the dispatcher decide (presumably — confirm).
    _attention_backend = None

    def __init__(self, dim: int, n_heads: int, n_kv_heads: int, qk_norm: bool = True, eps: float = 1e-5):
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.head_dim = dim // n_heads
        q_width = n_heads * self.head_dim
        kv_width = n_kv_heads * self.head_dim
        self.to_q = nn.Linear(dim, q_width, bias=False)
        self.to_k = nn.Linear(dim, kv_width, bias=False)
        self.to_v = nn.Linear(dim, kv_width, bias=False)
        self.to_out = nn.ModuleList([nn.Linear(q_width, dim, bias=False)])
        if qk_norm:
            self.norm_q = RMSNorm(self.head_dim, eps=eps)
            self.norm_k = RMSNorm(self.head_dim, eps=eps)
        else:
            self.norm_q = None
            self.norm_k = None

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        freqs_cis: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Project and split the last dim into (heads, head_dim).
        q = self.to_q(hidden_states).unflatten(-1, (self.n_heads, -1))
        k = self.to_k(hidden_states).unflatten(-1, (self.n_kv_heads, -1))
        v = self.to_v(hidden_states).unflatten(-1, (self.n_kv_heads, -1))
        if self.norm_q is not None:
            q = self.norm_q(q)
        if self.norm_k is not None:
            k = self.norm_k(k)
        if freqs_cis is not None:
            q = apply_rotary_emb(q, freqs_cis)
            k = apply_rotary_emb(k, freqs_cis)
        compute_dtype = q.dtype
        k = k.to(compute_dtype)

        from utils.attention import dispatch_attention

        attended = dispatch_attention(
            q, k, v, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, backend=self._attention_backend
        )
        merged = attended.flatten(2, 3).to(compute_dtype)
        return self.to_out[0](merged)
class ZImageTransformerBlock(nn.Module):
    """Transformer block with RMSNorms before and after attention and the feed-forward.

    With ``modulation`` enabled, ``adaln_input`` is projected to four chunks:
    input scales (applied as ``1 + s``) and residual gates (applied as
    ``tanh(g)``) for the attention and feed-forward branches.
    """

    def __init__(
        self,
        layer_id: int,
        dim: int,
        n_heads: int,
        n_kv_heads: int,
        norm_eps: float,
        qk_norm: bool,
        modulation=True,
    ):
        super().__init__()
        self.dim = dim
        self.head_dim = dim // n_heads
        self.layer_id = layer_id
        self.modulation = modulation
        self.attention = ZImageAttention(dim, n_heads, n_kv_heads, qk_norm, norm_eps)
        self.feed_forward = FeedForward(dim=dim, hidden_dim=int(dim / 3 * 8))
        self.attention_norm1 = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)
        self.attention_norm2 = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)
        if modulation:
            self.adaLN_modulation = nn.ModuleList([nn.Linear(min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True)])

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: torch.Tensor,
        freqs_cis: torch.Tensor,
        adaln_input: Optional[torch.Tensor] = None,
    ):
        if not self.modulation:
            attn_out = self.attention(self.attention_norm1(x), attention_mask=attn_mask, freqs_cis=freqs_cis)
            x = x + self.attention_norm2(attn_out)
            return x + self.ffn_norm2(self.feed_forward(self.ffn_norm1(x)))

        assert adaln_input is not None
        # Insert a sequence axis so the chunks broadcast over tokens.
        mod = self.adaLN_modulation[0](adaln_input).unsqueeze(1)
        scale_msa, gate_msa, scale_mlp, gate_mlp = mod.chunk(4, dim=2)
        scale_msa = 1.0 + scale_msa
        scale_mlp = 1.0 + scale_mlp
        gate_msa = gate_msa.tanh()
        gate_mlp = gate_mlp.tanh()

        attn_in = self.attention_norm1(x) * scale_msa
        attn_out = self.attention(attn_in, attention_mask=attn_mask, freqs_cis=freqs_cis)
        x = x + gate_msa * self.attention_norm2(attn_out)
        ffn_out = self.feed_forward(self.ffn_norm1(x) * scale_mlp)
        return x + gate_mlp * self.ffn_norm2(ffn_out)
class FinalLayer(nn.Module):
    """Non-affine LayerNorm scaled by ``1 + adaLN_modulation(c)``, then a linear projection."""

    def __init__(self, hidden_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, out_channels, bias=True)
        cond_size = min(hidden_size, ADALN_EMBED_DIM)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(cond_size, hidden_size, bias=True))

    def forward(self, x, c):
        # The per-sample scale gets a token axis so it broadcasts over the sequence.
        token_scale = (1.0 + self.adaLN_modulation(c)).unsqueeze(1)
        return self.linear(self.norm_final(x) * token_scale)
class RopeEmbedder:
    """Multi-axis rotary position lookup.

    Holds one complex frequency table per axis, built lazily on the first call
    and kept on the device of the most recent ``ids``. For position ids with
    one column per axis, the per-axis rows are concatenated along the last dim.
    """

    def __init__(
        self,
        theta: float = ROPE_THETA,
        axes_dims: List[int] = ROPE_AXES_DIMS,
        axes_lens: List[int] = ROPE_AXES_LENS,
    ):
        self.theta = theta
        self.axes_dims = axes_dims
        self.axes_lens = axes_lens
        assert len(axes_dims) == len(axes_lens)
        self.freqs_cis = None

    @staticmethod
    def precompute_freqs_cis(dim: List[int], end: List[int], theta: float = ROPE_THETA):
        """Build one complex64 table per axis on the CPU.

        Table ``i`` has shape ``(end[i], ceil(dim[i] / 2))``; row ``p`` holds
        ``exp(1j * p * theta ** (-k / dim[i]))`` for ``k = 0, 2, 4, ...``.
        """
        with torch.device("cpu"):
            tables = []
            for axis_dim, axis_len in zip(dim, end):
                inv_freq = 1.0 / (theta ** (torch.arange(0, axis_dim, 2, dtype=torch.float64, device="cpu") / axis_dim))
                positions = torch.arange(axis_len, device=inv_freq.device, dtype=torch.float64)
                angles = torch.outer(positions, inv_freq).float()
                tables.append(torch.polar(torch.ones_like(angles), angles).to(torch.complex64))
            return tables

    def __call__(self, ids: torch.Tensor):
        """Gather and concatenate per-axis rotary factors for ``ids`` of shape ``(N, n_axes)``."""
        assert ids.ndim == 2
        assert ids.shape[-1] == len(self.axes_dims)
        device = ids.device
        if self.freqs_cis is None:
            tables = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
            self.freqs_cis = [table.to(device) for table in tables]
        elif self.freqs_cis[0].device != device:
            self.freqs_cis = [table.to(device) for table in self.freqs_cis]
        per_axis = [self.freqs_cis[axis][ids[:, axis]] for axis in range(len(self.axes_dims))]
        return torch.cat(per_axis, dim=-1)
class ZImageTransformer2DModel(nn.Module):
    """Z-Image diffusion transformer over lists of variable-size latents.

    Image latents are patchified and embedded, then refined by ``noise_refiner``
    blocks (timestep-modulated). Caption features are embedded and refined by
    ``context_refiner`` blocks (not modulated). Per sample, image tokens and
    caption tokens are concatenated (image first) and processed by ``layers``,
    followed by a per-patch-size final layer and unpatchify.
    """

    def __init__(
        self,
        all_patch_size=(2,),
        all_f_patch_size=(1,),
        in_channels=16,
        dim=3840,
        n_layers=30,
        n_refiner_layers=2,
        n_heads=30,
        n_kv_heads=30,
        norm_eps=1e-5,
        qk_norm=True,
        cap_feat_dim=2560,
        rope_theta=ROPE_THETA,
        t_scale=1000.0,
        axes_dims=ROPE_AXES_DIMS,
        axes_lens=ROPE_AXES_LENS,
    ):
        super().__init__()
        self.in_channels = in_channels
        # Output channel count equals the input latent channel count.
        self.out_channels = in_channels
        self.all_patch_size = all_patch_size
        self.all_f_patch_size = all_f_patch_size
        self.dim = dim
        self.n_heads = n_heads
        self.rope_theta = rope_theta
        self.t_scale = t_scale
        assert len(all_patch_size) == len(all_f_patch_size)
        # One patch embedder and one final layer per (patch_size, f_patch_size)
        # pair, keyed as "<patch_size>-<f_patch_size>".
        all_x_embedder = {}
        all_final_layer = {}
        for patch_size, f_patch_size in zip(all_patch_size, all_f_patch_size):
            x_embedder = nn.Linear(f_patch_size * patch_size * patch_size * in_channels, dim, bias=True)
            all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder
            final_layer = FinalLayer(dim, patch_size * patch_size * f_patch_size * self.out_channels)
            all_final_layer[f"{patch_size}-{f_patch_size}"] = final_layer
        self.all_x_embedder = nn.ModuleDict(all_x_embedder)
        self.all_final_layer = nn.ModuleDict(all_final_layer)
        # NOTE(review): refiner layer ids are offset by 1000, presumably to keep
        # them distinct from the ids of ``self.layers`` — confirm.
        self.noise_refiner = nn.ModuleList(
            [
                ZImageTransformerBlock(1000 + layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation=True)
                for layer_id in range(n_refiner_layers)
            ]
        )
        self.context_refiner = nn.ModuleList(
            [
                ZImageTransformerBlock(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation=False)
                for layer_id in range(n_refiner_layers)
            ]
        )
        self.t_embedder = TimestepEmbedder(min(dim, ADALN_EMBED_DIM), mid_size=1024)
        self.cap_embedder = nn.Sequential(
            RMSNorm(cap_feat_dim, eps=norm_eps),
            nn.Linear(cap_feat_dim, dim, bias=True),
        )
        # torch.empty leaves these uninitialized; presumably they are always
        # overwritten by loaded weights — confirm.
        self.x_pad_token = nn.Parameter(torch.empty((1, dim)))
        self.cap_pad_token = nn.Parameter(torch.empty((1, dim)))
        self.layers = nn.ModuleList(
            [
                ZImageTransformerBlock(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm)
                for layer_id in range(n_layers)
            ]
        )
        head_dim = dim // n_heads
        # The RoPE axes together must span the full head dimension.
        assert head_dim == sum(axes_dims)
        self.axes_dims = axes_dims
        self.axes_lens = axes_lens
        self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens)

    def unpatchify(self, x: List[torch.Tensor], size: List[Tuple], patch_size, f_patch_size) -> List[torch.Tensor]:
        """Invert patchify for each sample.

        Drops tokens past the original (unpadded) length and reshapes the
        remaining ``(tokens, pF*pH*pW*C)`` rows back to ``(out_channels, F, H, W)``
        using the ``(F, H, W)`` in ``size``. ``x`` is updated in place and returned.
        """
        pH = pW = patch_size
        pF = f_patch_size
        bsz = len(x)
        assert len(size) == bsz
        for i in range(bsz):
            F, H, W = size[i]
            ori_len = (F // pF) * (H // pH) * (W // pW)
            # Inverse of the (F_t, H_t, W_t, pF, pH, pW, C) layout produced by patchify.
            x[i] = (
                x[i][:ori_len]
                .view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels)
                .permute(6, 0, 3, 1, 4, 2, 5)
                .reshape(self.out_channels, F, H, W)
            )
        return x

    @staticmethod
    def create_coordinate_grid(size, start=None, device=None):
        """Return int32 grid coordinates of shape ``(*size, len(size))``.

        Axis ``k`` ranges over ``start[k] .. start[k] + size[k] - 1``; ``start``
        defaults to all zeros.
        """
        if start is None:
            start = (0 for _ in size)
        axes = [torch.arange(x0, x0 + span, dtype=torch.int32, device=device) for x0, span in zip(start, size)]
        grids = torch.meshgrid(axes, indexing="ij")
        return torch.stack(grids, dim=-1)

    def patchify_and_embed(
        self,
        all_image: List[torch.Tensor],
        all_cap_feats: List[torch.Tensor],
        patch_size: int,
        f_patch_size: int,
    ):
        """Patchify images and pad image/caption token sequences to multiples of SEQ_MULTI_OF.

        Despite the name, no embedding layer is applied here; ``forward`` embeds
        the returned features. Padding repeats the last token's features, and the
        pad masks are True at padded positions.

        Returns:
            Seven per-sample lists: image patches, caption features, original
            image sizes ``(F, H, W)``, image position ids, caption position ids,
            image pad masks, caption pad masks.
        """
        pH = pW = patch_size
        pF = f_patch_size
        device = all_image[0].device
        all_image_out = []
        all_image_size = []
        all_image_pos_ids = []
        all_image_pad_mask = []
        all_cap_pos_ids = []
        all_cap_pad_mask = []
        all_cap_feats_out = []
        for _, (image, cap_feat) in enumerate(zip(all_image, all_cap_feats)):
            cap_ori_len = len(cap_feat)
            # Extra tokens needed to round the caption length up to a multiple of SEQ_MULTI_OF.
            cap_padding_len = (-cap_ori_len) % SEQ_MULTI_OF
            # Caption tokens take positions 1, 2, ... on the first RoPE axis and 0 on the others.
            cap_padded_pos_ids = self.create_coordinate_grid(
                size=(cap_ori_len + cap_padding_len, 1, 1),
                start=(1, 0, 0),
                device=device,
            ).flatten(0, 2)
            all_cap_pos_ids.append(cap_padded_pos_ids)
            # pad mask: True marks padded caption positions
            all_cap_pad_mask.append(
                torch.cat(
                    [
                        torch.zeros((cap_ori_len,), dtype=torch.bool, device=device),
                        torch.ones((cap_padding_len,), dtype=torch.bool, device=device),
                    ],
                    dim=0,
                )
                if cap_padding_len > 0
                else torch.zeros((cap_ori_len,), dtype=torch.bool, device=device)
            )
            # padded feature: repeat the last caption token
            all_cap_feats_out.append(
                torch.cat(
                    [cap_feat, cap_feat[-1:].repeat(cap_padding_len, 1)],
                    dim=0,
                )
                if cap_padding_len > 0
                else cap_feat
            )
            C, F, H, W = image.size()
            all_image_size.append((F, H, W))
            F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW
            # (C, F, H, W) -> (F_tokens*H_tokens*W_tokens, pF*pH*pW*C) patch rows.
            image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
            image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C)
            image_ori_len = len(image)
            image_padding_len = (-image_ori_len) % SEQ_MULTI_OF
            # Image positions: the first axis starts after the caption's positions;
            # the other two axes index the patch row/column.
            image_ori_pos_ids = self.create_coordinate_grid(
                size=(F_tokens, H_tokens, W_tokens),
                start=(cap_ori_len + cap_padding_len + 1, 0, 0),
                device=device,
            ).flatten(0, 2)
            # Padded image tokens all get position (0, 0, 0).
            image_padded_pos_ids = torch.cat(
                [
                    image_ori_pos_ids,
                    self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device)
                    .flatten(0, 2)
                    .repeat(image_padding_len, 1),
                ],
                dim=0,
            )
            all_image_pos_ids.append(image_padded_pos_ids if image_padding_len > 0 else image_ori_pos_ids)
            # pad mask: True marks padded image positions
            image_pad_mask = torch.cat(
                [
                    torch.zeros((image_ori_len,), dtype=torch.bool, device=device),
                    torch.ones((image_padding_len,), dtype=torch.bool, device=device),
                ],
                dim=0,
            )
            all_image_pad_mask.append(
                image_pad_mask
                if image_padding_len > 0
                else torch.zeros((image_ori_len,), dtype=torch.bool, device=device)
            )
            # padded feature: repeat the last image patch
            image_padded_feat = torch.cat(
                [image, image[-1:].repeat(image_padding_len, 1)],
                dim=0,
            )
            all_image_out.append(image_padded_feat if image_padding_len > 0 else image)
        return (
            all_image_out,
            all_cap_feats_out,
            all_image_size,
            all_image_pos_ids,
            all_cap_pos_ids,
            all_image_pad_mask,
            all_cap_pad_mask,
        )

    def forward(
        self,
        x: List[torch.Tensor],
        t,
        cap_feats: List[torch.Tensor],
        patch_size=2,
        f_patch_size=1,
    ):
        """Denoise a batch given as lists of per-sample tensors.

        Args:
            x: Per-sample latents of shape ``(C, F, H, W)``.
            t: Per-sample time values; multiplied by ``t_scale`` before embedding.
            cap_feats: Per-sample caption features ``(tokens, cap_feat_dim)``.
            patch_size, f_patch_size: Must be one of the configured pairs.

        Returns:
            ``(outputs, {})`` where ``outputs`` is a list of
            ``(out_channels, F, H, W)`` tensors, one per sample.
        """
        assert patch_size in self.all_patch_size
        assert f_patch_size in self.all_f_patch_size
        bsz = len(x)
        device = x[0].device
        t = t * self.t_scale
        t = self.t_embedder(t)
        (
            x,
            cap_feats,
            x_size,
            x_pos_ids,
            cap_pos_ids,
            x_inner_pad_mask,
            cap_inner_pad_mask,
        ) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size)
        # --- Image stream: embed all samples at once, then re-split per sample.
        x_item_seqlens = [len(_) for _ in x]
        assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens)
        x_max_item_seqlen = max(x_item_seqlens)
        x = torch.cat(x, dim=0)
        x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x)
        adaln_input = t.type_as(x)
        # Padded image tokens are replaced by the learned pad token.
        x[torch.cat(x_inner_pad_mask)] = self.x_pad_token
        x = list(x.split(x_item_seqlens, dim=0))
        x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split([len(_) for _ in x_pos_ids], dim=0))
        x = pad_sequence(x, batch_first=True, padding_value=0.0)
        x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0)
        # Clarify the length matches to satisfy Dynamo due to "Symbolic Shape Inference" to avoid compilation errors
        x_freqs_cis = x_freqs_cis[:, : x.shape[1]]
        # True over each sample's own (SEQ_MULTI_OF-padded) tokens, False over batch padding.
        x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device)
        for i, seq_len in enumerate(x_item_seqlens):
            x_attn_mask[i, :seq_len] = 1
        for layer in self.noise_refiner:
            x = layer(x, x_attn_mask, x_freqs_cis, adaln_input)
        # --- Caption stream: same embed / split / pad / mask pattern, no modulation.
        cap_item_seqlens = [len(_) for _ in cap_feats]
        assert all(_ % SEQ_MULTI_OF == 0 for _ in cap_item_seqlens)
        cap_max_item_seqlen = max(cap_item_seqlens)
        cap_feats = torch.cat(cap_feats, dim=0)
        cap_feats = self.cap_embedder(cap_feats)
        cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token
        cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0))
        cap_freqs_cis = list(
            self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split([len(_) for _ in cap_pos_ids], dim=0)
        )
        cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0)
        cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0)
        cap_freqs_cis = cap_freqs_cis[:, : cap_feats.shape[1]]  # same for dynamo compatibility
        cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device)
        for i, seq_len in enumerate(cap_item_seqlens):
            cap_attn_mask[i, :seq_len] = 1
        for layer in self.context_refiner:
            cap_feats = layer(cap_feats, cap_attn_mask, cap_freqs_cis)
        # --- Joint stream: per sample, image tokens followed by caption tokens.
        unified = []
        unified_freqs_cis = []
        for i in range(bsz):
            x_len = x_item_seqlens[i]
            cap_len = cap_item_seqlens[i]
            unified.append(torch.cat([x[i][:x_len], cap_feats[i][:cap_len]]))
            unified_freqs_cis.append(torch.cat([x_freqs_cis[i][:x_len], cap_freqs_cis[i][:cap_len]]))
        unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)]
        assert unified_item_seqlens == [len(_) for _ in unified]
        unified_max_item_seqlen = max(unified_item_seqlens)
        unified = pad_sequence(unified, batch_first=True, padding_value=0.0)
        unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0)
        unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device)
        for i, seq_len in enumerate(unified_item_seqlens):
            unified_attn_mask[i, :seq_len] = 1
        for layer in self.layers:
            unified = layer(unified, unified_attn_mask, unified_freqs_cis, adaln_input)
        unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, adaln_input)
        # unpatchify keeps only each sample's leading image tokens (caption tokens follow them).
        unified = list(unified.unbind(dim=0))
        x = self.unpatchify(unified, x_size, patch_size, f_patch_size)
        return x, {}