Repository: Wan-Video/Wan2.1
Branch: main
Commit: 9737cba9c1c3
Files: 45
Total size: 448.1 KB
Directory structure:
gitextract_u176_3_t/
├── .gitignore
├── INSTALL.md
├── LICENSE.txt
├── Makefile
├── README.md
├── generate.py
├── gradio/
│ ├── fl2v_14B_singleGPU.py
│ ├── i2v_14B_singleGPU.py
│ ├── t2i_14B_singleGPU.py
│ ├── t2v_1.3B_singleGPU.py
│ ├── t2v_14B_singleGPU.py
│ └── vace.py
├── pyproject.toml
├── requirements.txt
├── tests/
│ ├── README.md
│ └── test.sh
└── wan/
├── __init__.py
├── configs/
│ ├── __init__.py
│ ├── shared_config.py
│ ├── wan_i2v_14B.py
│ ├── wan_t2v_14B.py
│ └── wan_t2v_1_3B.py
├── distributed/
│ ├── __init__.py
│ ├── fsdp.py
│ └── xdit_context_parallel.py
├── first_last_frame2video.py
├── image2video.py
├── modules/
│ ├── __init__.py
│ ├── attention.py
│ ├── clip.py
│ ├── model.py
│ ├── t5.py
│ ├── tokenizers.py
│ ├── vace_model.py
│ ├── vae.py
│ └── xlm_roberta.py
├── text2video.py
├── utils/
│ ├── __init__.py
│ ├── fm_solvers.py
│ ├── fm_solvers_unipc.py
│ ├── prompt_extend.py
│ ├── qwen_vl_utils.py
│ ├── utils.py
│ └── vace_processor.py
└── vace.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
.*
*.py[cod]
# *.jpg
*.jpeg
# *.png
*.gif
*.bmp
*.mp4
*.mov
*.mkv
*.log
*.zip
*.pt
*.pth
*.ckpt
*.safetensors
*.json
# *.txt
*.backup
*.pkl
*.html
*.pdf
*.whl
cache
__pycache__/
storage/
samples/
!.gitignore
!requirements.txt
.DS_Store
*DS_Store
google/
Wan2.1-T2V-14B/
Wan2.1-T2V-1.3B/
Wan2.1-I2V-14B-480P/
Wan2.1-I2V-14B-720P/
poetry.lock
================================================
FILE: INSTALL.md
================================================
# Installation Guide
## Install with pip
```bash
pip install .
pip install .[dev] # Also installs the development tools
```
## Install with Poetry
Ensure you have [Poetry](https://python-poetry.org/docs/#installation) installed on your system.
To install all dependencies:
```bash
poetry install
```
### Handling `flash-attn` Installation Issues
If `flash-attn` fails due to **PEP 517 build issues**, you can try one of the following fixes.
#### No-Build-Isolation Installation (Recommended)
```bash
poetry run pip install --upgrade pip setuptools wheel
poetry run pip install flash-attn --no-build-isolation
poetry install
```
#### Install from Git (Alternative)
```bash
poetry run pip install git+https://github.com/Dao-AILab/flash-attention.git
```
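After either fix, you can quickly confirm that `flash-attn` is visible in the environment Poetry manages. The sketch below assumes only that the package is importable under the name `flash_attn`. It does not check that the compiled kernels match your GPU or CUDA version.

```python
import importlib.util


def flash_attn_installed() -> bool:
    """Return True if the `flash_attn` package can be located on the import path.

    find_spec only locates the package without importing it, so a build that is
    present but incompatible with the local CUDA/GPU setup can still fail later,
    the first time the model calls into it.
    """
    return importlib.util.find_spec("flash_attn") is not None


if __name__ == "__main__":
    print("flash-attn available:", flash_attn_installed())
```

Save it under any name (for example a hypothetical `check_flash_attn.py`) and run it with `poetry run python check_flash_attn.py`.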
---
### Running the Model
Once the installation is complete, you can run **Wan2.1** using:
```bash
poetry run python generate.py --task t2v-14B --size '1280*720' --ckpt_dir ./Wan2.1-T2V-14B --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
```
#### Test
```bash
pytest tests/
```
#### Format
```bash
black .
isort .
```
================================================
FILE: LICENSE.txt
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: Makefile
================================================
.PHONY: format
format:
	isort generate.py gradio wan
	yapf -i -r *.py generate.py gradio wan
================================================
FILE: README.md
================================================
# Wan2.1
<p align="center">
<img src="assets/logo.png" width="400"/>
</p>
<p align="center">
💜 <a href="https://wan.video"><b>Wan</b></a>    |    🖥️ <a href="https://github.com/Wan-Video/Wan2.1">GitHub</a>    |   🤗 <a href="https://huggingface.co/Wan-AI/">Hugging Face</a>   |   🤖 <a href="https://modelscope.cn/organization/Wan-AI">ModelScope</a>   |    📑 <a href="https://arxiv.org/abs/2503.20314">Technical Report</a>    |    📑 <a href="https://wan.video/welcome?spm=a2ty_o02.30011076.0.0.6c9ee41eCcluqg">Blog</a>    |   💬 <a href="https://gw.alicdn.com/imgextra/i2/O1CN01tqjWFi1ByuyehkTSB_!!6000000000015-0-tps-611-1279.jpg">WeChat Group</a>   |    📖 <a href="https://discord.gg/AKNgpMK4Yj">Discord</a>  
<br>
-----
[**Wan: Open and Advanced Large-Scale Video Generative Models**](https://arxiv.org/abs/2503.20314) <br>
In this repository, we present **Wan2.1**, a comprehensive and open suite of video foundation models that pushes the boundaries of video generation. **Wan2.1** offers these key features:
- 👍 **SOTA Performance**: **Wan2.1** consistently outperforms existing open-source models and state-of-the-art commercial solutions across multiple benchmarks.
- 👍 **Supports Consumer-grade GPUs**: The T2V-1.3B model requires only 8.19 GB VRAM, making it compatible with almost all consumer-grade GPUs. It can generate a 5-second 480P video on an RTX 4090 in about 4 minutes (without optimization techniques like quantization). Its performance is even comparable to some closed-source models.
- 👍 **Multiple Tasks**: **Wan2.1** excels in Text-to-Video, Image-to-Video, Video Editing, Text-to-Image, and Video-to-Audio, advancing the field of video generation.
- 👍 **Visual Text Generation**: **Wan2.1** is the first video model capable of generating both Chinese and English text, featuring robust text generation that enhances its practical applications.
- 👍 **Powerful Video VAE**: **Wan-VAE** delivers exceptional efficiency and performance, encoding and decoding 1080P videos of any length while preserving temporal information, making it an ideal foundation for video and image generation.
## Video Demos
<div align="center">
<video src="https://github.com/user-attachments/assets/4aca6063-60bf-4953-bfb7-e265053f49ef" width="70%" poster=""> </video>
</div>
## 🔥 Latest News!!
* May 14, 2025: 👋 We introduce **Wan2.1** [VACE](https://github.com/ali-vilab/VACE), an all-in-one model for video creation and editing, along with its [inference code](#run-vace), [weights](#model-download), and [technical report](https://arxiv.org/abs/2503.07598)!
* Apr 17, 2025: 👋 We introduce **Wan2.1** [FLF2V](#run-first-last-frame-to-video-generation) with its inference code and weights!
* Mar 21, 2025: 👋 We are excited to announce the release of the **Wan2.1** [technical report](https://files.alicdn.com/tpsservice/5c9de1c74de03972b7aa657e5a54756b.pdf). We welcome discussions and feedback!
* Mar 3, 2025: 👋 **Wan2.1**'s T2V and I2V have been integrated into Diffusers ([T2V](https://huggingface.co/docs/diffusers/main/en/api/pipelines/wan#diffusers.WanPipeline) | [I2V](https://huggingface.co/docs/diffusers/main/en/api/pipelines/wan#diffusers.WanImageToVideoPipeline)). Feel free to give it a try!
* Feb 27, 2025: 👋 **Wan2.1** has been integrated into [ComfyUI](https://comfyanonymous.github.io/ComfyUI_examples/wan/). Enjoy!
* Feb 25, 2025: 👋 We've released the inference code and weights of **Wan2.1**.
## Community Works
If your work has improved **Wan2.1** and you would like more people to see it, please let us know.
- [Helios](https://github.com/PKU-YuanGroup/Helios), a breakthrough video generation model based on **Wan2.1** that achieves minute-scale, high-quality video synthesis at 19.5 FPS on a single H100 GPU (about 10 FPS on a single Ascend NPU) without relying on conventional long-video anti-drifting strategies or standard video acceleration techniques. Visit their [webpage](https://pku-yuangroup.github.io/Helios-Page/) for more details.
- [Video-As-Prompt](https://github.com/bytedance/Video-As-Prompt), the first unified semantic-controlled video generation model based on **Wan2.1-14B-I2V** with a Mixture-of-Transformers architecture and in-context controls (e.g., concept, style, motion, camera). Refer to the [project page](https://bytedance.github.io/Video-As-Prompt/) for more examples.
- [LightX2V](https://github.com/ModelTC/LightX2V), a lightweight and efficient video generation framework that integrates **Wan2.1** and **Wan2.2** and supports multiple engineering acceleration techniques for fast inference. It can run on GPUs such as the RTX 5090 and RTX 4060 (8GB VRAM).
- [DriVerse](https://github.com/shalfun/DriVerse), an autonomous driving world model based on **Wan2.1-14B-I2V** that generates future driving videos conditioned on any scene frame and a given trajectory. Refer to the [project page](https://github.com/shalfun/DriVerse/tree/main) for more examples.
- [Training-Free-WAN-Editing](https://github.com/KyujinHan/Awesome-Training-Free-WAN2.1-Editing), built on **Wan2.1-T2V-1.3B**, allows training-free video editing with image-based training-free methods, such as [FlowEdit](https://arxiv.org/abs/2412.08629) and [FlowAlign](https://arxiv.org/abs/2505.23145).
- [Wan-Move](https://github.com/ali-vilab/Wan-Move), accepted to NeurIPS 2025, a framework that brings **Wan2.1-I2V-14B** to SOTA fine-grained, point-level motion control! Refer to [their project page](https://wan-move.github.io/) for more information.
- [EchoShot](https://github.com/JoHnneyWang/EchoShot), a native multi-shot portrait video generation model based on **Wan2.1-T2V-1.3B**, allows generation of multiple video clips featuring the same character as well as highly flexible content controllability. Refer to [their project page](https://johnneywang.github.io/EchoShot-webpage/) for more information.
- [AniCrafter](https://github.com/MyNiuuu/AniCrafter), a human-centric animation model based on **Wan2.1-14B-I2V**, controls the Video Diffusion Models with 3DGS Avatars to insert and animate anyone into any scene following given motion sequences. Refer to the [project page](https://myniuuu.github.io/AniCrafter) for more examples.
- [HyperMotion](https://vivocameraresearch.github.io/hypermotion/), a human image animation framework based on **Wan2.1**, addresses the challenge of generating complex human body motions in pose-guided animation. Refer to [their website](https://vivocameraresearch.github.io/hypermotion/) for more examples.
- [MagicTryOn](https://vivocameraresearch.github.io/magictryon/), a video virtual try-on framework built upon **Wan2.1-14B-I2V**, addresses the limitations of existing models in expressing garment details and maintaining dynamic stability during human motion. Refer to [their website](https://vivocameraresearch.github.io/magictryon/) for more examples.
- [ATI](https://github.com/bytedance/ATI), built on **Wan2.1-I2V-14B**, is a trajectory-based motion-control framework that unifies object, local, and camera movements in video generation. Refer to [their website](https://anytraj.github.io/) for more examples.
- [Phantom](https://github.com/Phantom-video/Phantom) has developed a unified video generation framework for single and multi-subject references based on both **Wan2.1-T2V-1.3B** and **Wan2.1-T2V-14B**. Please refer to [their examples](https://github.com/Phantom-video/Phantom).
- [UniAnimate-DiT](https://github.com/ali-vilab/UniAnimate-DiT), based on **Wan2.1-14B-I2V**, has trained a human image animation model and open-sourced the inference and training code. Feel free to enjoy it!
- [CFG-Zero](https://github.com/WeichenFan/CFG-Zero-star) enhances **Wan2.1** (covering both T2V and I2V models) from the perspective of CFG.
- [TeaCache](https://github.com/ali-vilab/TeaCache) now supports **Wan2.1** acceleration, capable of increasing speed by approximately 2x. Feel free to give it a try!
- [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio) provides more support for **Wan2.1**, including video-to-video, FP8 quantization, VRAM optimization, LoRA training, and more. Please refer to [their examples](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo).
## 📑 Todo List
- Wan2.1 Text-to-Video
    - [x] Multi-GPU Inference code of the 14B and 1.3B models
    - [x] Checkpoints of the 14B and 1.3B models
    - [x] Gradio demo
    - [x] ComfyUI integration
    - [x] Diffusers integration
    - [ ] Diffusers + Multi-GPU Inference
- Wan2.1 Image-to-Video
    - [x] Multi-GPU Inference code of the 14B model
    - [x] Checkpoints of the 14B model
    - [x] Gradio demo
    - [x] ComfyUI integration
    - [x] Diffusers integration
    - [ ] Diffusers + Multi-GPU Inference
- Wan2.1 First-Last-Frame-to-Video
    - [x] Multi-GPU Inference code of the 14B model
    - [x] Checkpoints of the 14B model
    - [x] Gradio demo
    - [ ] ComfyUI integration
    - [ ] Diffusers integration
    - [ ] Diffusers + Multi-GPU Inference
- Wan2.1 VACE
    - [x] Multi-GPU Inference code of the 14B and 1.3B models
    - [x] Checkpoints of the 14B and 1.3B models
    - [x] Gradio demo
    - [x] ComfyUI integration
    - [ ] Diffusers integration
    - [ ] Diffusers + Multi-GPU Inference
## Quickstart
#### Installation
Clone the repo:
```sh
git clone https://github.com/Wan-Video/Wan2.1.git
cd Wan2.1
```
Install dependencies:
```sh
# Ensure torch >= 2.4.0
pip install -r requirements.txt
```
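If you are unsure which PyTorch build is installed, a small version check can catch an outdated install before the model fails to load. This is a minimal sketch: it compares only the leading numeric components of `torch.__version__` against the `2.4.0` minimum mentioned above, ignoring local build tags such as `+cu121` and pre-release suffixes.

```python
import re


def parse_version(version: str) -> tuple:
    """Turn a version string such as '2.4.0+cu121' into (2, 4, 0).

    The local build tag after '+' is dropped, and only the leading digits of the
    first three dot-separated components are kept (so '0rc1' counts as 0).
    """
    release = version.split("+", 1)[0]
    parts = []
    for piece in release.split(".")[:3]:
        match = re.match(r"\d+", piece)
        parts.append(int(match.group()) if match else 0)
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)


def torch_is_recent_enough(version: str, minimum: str = "2.4.0") -> bool:
    # Tuple comparison handles multi-digit components correctly, e.g. (2, 10, 0) > (2, 4, 0).
    return parse_version(version) >= parse_version(minimum)


if __name__ == "__main__":
    try:
        import torch
    except ImportError:
        print("torch is not installed")
    else:
        status = "OK" if torch_is_recent_enough(torch.__version__) else "older than 2.4.0"
        print(f"torch {torch.__version__}: {status}")
```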
#### Model Download
| Models | Download Link | Notes |
|--------------|---------------------------------------------------------------------------------------------------------------------------------------------------------|-------------------------------|
| T2V-14B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-T2V-14B) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B) | Supports both 480P and 720P
| I2V-14B-720P | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-720P) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P) | Supports 720P
| I2V-14B-480P | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-480P) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P) | Supports 480P
| T2V-1.3B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B) | Supports 480P
| FLF2V-14B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-FLF2V-14B-720P) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P) | Supports 720P
| VACE-1.3B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-VACE-1.3B) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B) | Supports 480P
| VACE-14B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-VACE-14B) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B) | Supports both 480P and 720P
> 💡Note:
> * The 1.3B model is capable of generating videos at 720P resolution. However, due to limited training at this resolution, the results are generally less stable compared to 480P. For optimal performance, we recommend using 480P resolution.
> * For first-last-frame-to-video generation, we train our model primarily on Chinese text-video pairs. Therefore, we recommend using Chinese prompts to achieve better results.
Download models using huggingface-cli:
``` sh
pip install "huggingface_hub[cli]"
huggingface-cli download Wan-AI/Wan2.1-T2V-14B --local-dir ./Wan2.1-T2V-14B
```
Download models using modelscope-cli:
``` sh
pip install modelscope
modelscope download Wan-AI/Wan2.1-T2V-14B --local_dir ./Wan2.1-T2V-14B
```
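The same checkpoints can also be fetched from Python. The sketch below uses `huggingface_hub.snapshot_download`, the library function behind the CLI download. The repository IDs are copied from the Model Download table, and the default target directory mirrors the `--local-dir` convention used in the commands above. The short model keys (`"T2V-14B"` and so on) are just labels chosen for this example.

```python
from typing import Optional

# Repository IDs copied from the Model Download table above.
WAN_REPOS = {
    "T2V-14B": "Wan-AI/Wan2.1-T2V-14B",
    "I2V-14B-720P": "Wan-AI/Wan2.1-I2V-14B-720P",
    "I2V-14B-480P": "Wan-AI/Wan2.1-I2V-14B-480P",
    "T2V-1.3B": "Wan-AI/Wan2.1-T2V-1.3B",
    "FLF2V-14B": "Wan-AI/Wan2.1-FLF2V-14B-720P",
    "VACE-1.3B": "Wan-AI/Wan2.1-VACE-1.3B",
    "VACE-14B": "Wan-AI/Wan2.1-VACE-14B",
}


def default_local_dir(model: str) -> str:
    # Same layout as the CLI examples: ./<repo name without the "Wan-AI/" prefix>
    return "./" + WAN_REPOS[model].split("/", 1)[1]


def download(model: str, local_dir: Optional[str] = None) -> str:
    """Download one checkpoint and return the local directory path."""
    # Imported lazily so the mapping above works without huggingface_hub installed.
    from huggingface_hub import snapshot_download

    return snapshot_download(
        repo_id=WAN_REPOS[model],
        local_dir=local_dir or default_local_dir(model),
    )
```

For example, `download("T2V-1.3B")` should place the files in `./Wan2.1-T2V-1.3B`, the same directory used by the `--ckpt_dir` examples below.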
#### Run Text-to-Video Generation
This repository supports two Text-to-Video models (1.3B and 14B) and two resolutions (480P and 720P). The parameters and configurations for these models are as follows:
<table>
<thead>
<tr>
<th rowspan="2">Task</th>
<th colspan="2">Resolution</th>
<th rowspan="2">Model</th>
</tr>
<tr>
<th>480P</th>
<th>720P</th>
</tr>
</thead>
<tbody>
<tr>
<td>t2v-14B</td>
<td style="color: green;">✔️</td>
<td style="color: green;">✔️</td>
<td>Wan2.1-T2V-14B</td>
</tr>
<tr>
<td>t2v-1.3B</td>
<td style="color: green;">✔️</td>
<td style="color: red;">❌</td>
<td>Wan2.1-T2V-1.3B</td>
</tr>
</tbody>
</table>
##### (1) Without Prompt Extension
To keep things simple, we start with a basic version of the inference process that skips the [prompt extension](#2-using-prompt-extension) step.
- Single-GPU inference
``` sh
python generate.py --task t2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-T2V-14B --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
```
If you encounter OOM (Out-of-Memory) issues, you can use the `--offload_model True` and `--t5_cpu` options to reduce GPU memory usage. For example, on an RTX 4090 GPU:
``` sh
python generate.py --task t2v-1.3B --size 832*480 --ckpt_dir ./Wan2.1-T2V-1.3B --offload_model True --t5_cpu --sample_shift 8 --sample_guide_scale 6 --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
```
> 💡Note: If you are using the `T2V-1.3B` model, we recommend setting `--sample_guide_scale 6`. The `--sample_shift` parameter can be tuned within the range of 8 to 12, depending on the results.
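To see intuitively what `--sample_shift` changes, it helps to look at the timestep shift commonly used by flow-matching samplers. The formula below is the SD3-style shift, σ' = s·σ / (1 + (s − 1)·σ). It is an assumption for illustration here, not a quotation of the sampler code in `wan/utils`. The sketch shows that larger shifts push intermediate noise levels toward 1.0 while leaving the endpoints unchanged.

```python
def shift_sigmas(sigmas, shift):
    """Apply the flow-matching timestep shift sigma' = shift * sigma / (1 + (shift - 1) * sigma).

    shift == 1 leaves the schedule unchanged. shift > 1 moves intermediate sigmas
    toward 1.0 (more noise), so more sampling steps are spent at high noise levels.
    The endpoints 0.0 and 1.0 are fixed points of the mapping.
    """
    return [shift * s / (1 + (shift - 1) * s) for s in sigmas]


if __name__ == "__main__":
    base = [1.0, 0.75, 0.5, 0.25, 0.0]
    for shift in (1.0, 8.0, 12.0):
        print(shift, [round(s, 3) for s in shift_sigmas(base, shift)])
```

Under this assumption, moving from `--sample_shift 8` to `12` only changes how the steps are distributed between high and low noise. It does not change the number of steps.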
- Multi-GPU inference using FSDP + xDiT USP
We use FSDP and [xDiT](https://github.com/xdit-project/xDiT) USP to accelerate inference.
* Ulysses Strategy
If you want to use the [`Ulysses`](https://arxiv.org/abs/2309.14509) strategy, set `--ulysses_size $GPU_NUMS`. Note that `num_heads` must be divisible by `ulysses_size` to use the `Ulysses` strategy. For the 1.3B model, `num_heads` is `12`, which is not divisible by 8 (most multi-GPU machines have 8 GPUs). Therefore, we recommend the `Ring Strategy` for the 1.3B model instead.
* Ring Strategy
If you want to use the [`Ring`](https://arxiv.org/pdf/2310.01889) strategy, set `--ring_size $GPU_NUMS`. Note that the `sequence length` must be divisible by `ring_size` when using the `Ring` strategy.
You can also combine the `Ulysses` and `Ring` strategies.
``` sh
pip install "xfuser>=0.4.1"
torchrun --nproc_per_node=8 generate.py --task t2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-T2V-14B --dit_fsdp --t5_fsdp --ulysses_size 8 --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
```
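The divisibility rules for the `Ulysses` and `Ring` strategies can be checked before launching `torchrun`. The sketch below estimates the token count of one video and validates a proposed split. It makes several assumptions that you should verify against `wan/configs` and `generate.py` before relying on it: a VAE stride of (4, 8, 8), a patch size of (1, 2, 2), 40 attention heads for the 14B model (12 for 1.3B, as stated above), and the requirement that `ulysses_size * ring_size` equal the number of GPUs.

```python
def latent_tokens(width, height, num_frames, vae_stride=(4, 8, 8), patch_size=(1, 2, 2)):
    """Estimate the DiT sequence length for one video (strides are assumptions)."""
    # Assumed VAE layout: the first frame is encoded on its own, the rest in groups of 4.
    frames = (num_frames - 1) // vae_stride[0] + 1
    lat_h = height // vae_stride[1]
    lat_w = width // vae_stride[2]
    return (frames // patch_size[0]) * (lat_h // patch_size[1]) * (lat_w // patch_size[2])


def check_usp_config(num_heads, seq_len, world_size, ulysses_size=1, ring_size=1):
    """Return a list of problems with a Ulysses/Ring split; an empty list means OK."""
    problems = []
    if ulysses_size * ring_size != world_size:
        problems.append(
            f"ulysses_size * ring_size = {ulysses_size * ring_size}, expected {world_size} GPUs")
    if num_heads % ulysses_size != 0:
        problems.append(f"num_heads={num_heads} is not divisible by ulysses_size={ulysses_size}")
    if seq_len % ring_size != 0:
        problems.append(f"seq_len={seq_len} is not divisible by ring_size={ring_size}")
    return problems


if __name__ == "__main__":
    seq_480p = latent_tokens(832, 480, 81)
    # 1.3B model (12 heads) on 8 GPUs: Ulysses fails, Ring passes.
    print(check_usp_config(12, seq_480p, world_size=8, ulysses_size=8))
    print(check_usp_config(12, seq_480p, world_size=8, ring_size=8))
```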
##### (2) Using Prompt Extension
Extending the prompts can effectively enrich the details in the generated videos, further enhancing the video quality. Therefore, we recommend enabling prompt extension. We provide the following two methods for prompt extension:
- Use the Dashscope API for extension.
- Apply for a `dashscope.api_key` in advance ([EN](https://www.alibabacloud.com/help/en/model-studio/getting-started/first-api-call-to-qwen) | [CN](https://help.aliyun.com/zh/model-studio/getting-started/first-api-call-to-qwen)).
- Configure the environment variable `DASH_API_KEY` to specify the Dashscope API key. For users of Alibaba Cloud's international site, you also need to set the environment variable `DASH_API_URL` to 'https://dashscope-intl.aliyuncs.com/api/v1'. For more detailed instructions, please refer to the [dashscope document](https://www.alibabacloud.com/help/en/model-studio/developer-reference/use-qwen-by-calling-api?spm=a2c63.p38356.0.i1).
- Use the `qwen-plus` model for text-to-video tasks and `qwen-vl-max` for image-to-video tasks.
- You can modify the model used for extension with the parameter `--prompt_extend_model`. For example:
```sh
DASH_API_KEY=your_key python generate.py --task t2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-T2V-14B --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage" --use_prompt_extend --prompt_extend_method 'dashscope' --prompt_extend_target_lang 'zh'
```
- Using a local model for extension.
- By default, the Qwen model on HuggingFace is used for this extension. Users can choose Qwen models or other models based on the available GPU memory size.
- For text-to-video tasks, you can use models like `Qwen/Qwen2.5-14B-Instruct`, `Qwen/Qwen2.5-7B-Instruct` and `Qwen/Qwen2.5-3B-Instruct`.
- For image-to-video or first-last-frame-to-video tasks, you can use models like `Qwen/Qwen2.5-VL-7B-Instruct` and `Qwen/Qwen2.5-VL-3B-Instruct`.
- Larger models generally provide better extension results but require more GPU memory.
- You can change the model used for extension with the `--prompt_extend_model` parameter, which accepts either a local model path or a Hugging Face model ID. For example:
``` sh
python generate.py --task t2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-T2V-14B --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage" --use_prompt_extend --prompt_extend_method 'local_qwen' --prompt_extend_target_lang 'zh'
```
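If you launch generation from a script rather than a shell, the two prompt extension methods differ only in their environment variables and flags. The sketch below builds the environment and argument list for `generate.py`, using only the flags and variables shown above (`DASH_API_KEY`, `DASH_API_URL`, `--use_prompt_extend`, `--prompt_extend_method`, `--prompt_extend_target_lang`, `--prompt_extend_model`). The helper name and its defaults are hypothetical.

```python
import os
import shlex
from typing import Optional

DASHSCOPE_INTL_URL = "https://dashscope-intl.aliyuncs.com/api/v1"


def prompt_extend_command(prompt: str, method: str = "dashscope", task: str = "t2v-14B",
                          size: str = "1280*720", ckpt_dir: str = "./Wan2.1-T2V-14B",
                          target_lang: str = "zh", extend_model: Optional[str] = None,
                          api_key: Optional[str] = None, intl: bool = False):
    """Return (env, argv) for running generate.py with prompt extension enabled."""
    if method not in ("dashscope", "local_qwen"):
        raise ValueError(f"unknown prompt_extend_method: {method}")
    env = dict(os.environ)
    if method == "dashscope":
        key = api_key or env.get("DASH_API_KEY")
        if not key:
            raise ValueError("DASH_API_KEY is required for the dashscope method")
        env["DASH_API_KEY"] = key
        if intl:  # users of Alibaba Cloud's international site
            env["DASH_API_URL"] = DASHSCOPE_INTL_URL
    argv = ["python", "generate.py", "--task", task, "--size", size,
            "--ckpt_dir", ckpt_dir, "--prompt", prompt, "--use_prompt_extend",
            "--prompt_extend_method", method, "--prompt_extend_target_lang", target_lang]
    if extend_model:  # a Hugging Face model ID or a local path
        argv += ["--prompt_extend_model", extend_model]
    return env, argv


if __name__ == "__main__":
    _, argv = prompt_extend_command("Two cats boxing on a stage", method="local_qwen",
                                    extend_model="Qwen/Qwen2.5-7B-Instruct")
    print(shlex.join(argv))
```

The resulting pair can be passed to `subprocess.run(argv, env=env, check=True)` from the repository root.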
##### (3) Running with Diffusers
You can easily run inference with **Wan2.1**-T2V using Diffusers with the following code:
``` python
import torch
from diffusers.utils import export_to_video
from diffusers import AutoencoderKLWan, WanPipeline
from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
# Available models: Wan-AI/Wan2.1-T2V-14B-Diffusers, Wan-AI/Wan2.1-T2V-1.3B-Diffusers
model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
flow_shift = 5.0 # 5.0 for 720P, 3.0 for 480P
scheduler = UniPCMultistepScheduler(prediction_type='flow_prediction', use_flow_sigmas=True, num_train_timesteps=1000, flow_shift=flow_shift)
pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
pipe.scheduler = scheduler
pipe.to("cuda")
prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
output = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=720,
    width=1280,
    num_frames=81,
    guidance_scale=5.0,
).frames[0]
export_to_video(output, "output.mp4", fps=16)
```
> 💡Note: This example does not include prompt extension or distributed inference. We will soon update it with integrated prompt extension and a multi-GPU version for Diffusers.
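The `flow_shift` setting in the Diffusers example rescales the sampler's noise levels. As a minimal sketch, assuming the shift formula commonly used by flow-matching schedulers, sigma' = s * sigma / (1 + (s - 1) * sigma), the snippet below shows how larger shifts keep the schedule at higher noise levels for longer. `shift_sigmas` and `linear_sigmas` are hypothetical helpers, not Diffusers APIs.

```python
def linear_sigmas(num_steps):
    """Evenly spaced noise levels from 1.0 (pure noise) down to 0.0 (clean sample)."""
    return [1 - i / num_steps for i in range(num_steps + 1)]


def shift_sigmas(sigmas, shift):
    """Assumed flow-matching shift: sigma' = shift * sigma / (1 + (shift - 1) * sigma)."""
    return [shift * s / (1 + (shift - 1) * s) for s in sigmas]


base = linear_sigmas(4)  # [1.0, 0.75, 0.5, 0.25, 0.0]
for shift in (1.0, 3.0, 5.0):
    # shift=1.0 leaves the schedule unchanged; larger values push sigmas toward 1.
    print(shift, [round(s, 3) for s in shift_sigmas(base, shift)])
```

Both endpoints (1.0 and 0.0) are fixed points of this mapping, so only the spacing of the intermediate steps changes.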
##### (4) Running local gradio
``` sh
cd gradio
# if one uses dashscope’s API for prompt extension
DASH_API_KEY=your_key python t2v_14B_singleGPU.py --prompt_extend_method 'dashscope' --ckpt_dir ./Wan2.1-T2V-14B
# if one uses a local model for prompt extension
python t2v_14B_singleGPU.py --prompt_extend_method 'local_qwen' --ckpt_dir ./Wan2.1-T2V-14B
```
#### Run Image-to-Video Generation
Similar to Text-to-Video, Image-to-Video is also divided into processes with and without the prompt extension step. The specific parameters and their corresponding settings are as follows:
<table>
<thead>
<tr>
<th rowspan="2">Task</th>
<th colspan="2">Resolution</th>
<th rowspan="2">Model</th>
</tr>
<tr>
<th>480P</th>
<th>720P</th>
</tr>
</thead>
<tbody>
<tr>
<td>i2v-14B</td>
<td style="color: red;">❌</td>
<td style="color: green;">✔️</td>
<td>Wan2.1-I2V-14B-720P</td>
</tr>
<tr>
<td>i2v-14B</td>
<td style="color: green;">✔️</td>
<td style="color: red;">❌</td>
<td>Wan2.1-I2V-14B-480P</td>
</tr>
</tbody>
</table>
##### (1) Without Prompt Extension
- Single-GPU inference
```sh
python generate.py --task i2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-I2V-14B-720P --image examples/i2v_input.JPG --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
```
> 💡For the Image-to-Video task, the `size` parameter represents the area of the generated video, with the aspect ratio following that of the original input image.
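To show concretely how a target area and the input's aspect ratio determine the output size, here is a minimal sketch that mirrors the rounding in the Diffusers I2V example below. It assumes both sides must be multiples of 16 (a VAE spatial stride of 8 times a patch size of 2). `fit_to_area` is a hypothetical helper, and `generate.py` may round differently internally.

```python
import math


def fit_to_area(img_width, img_height, max_area=1280 * 720, mod_value=16):
    """Pick (width, height) close to max_area pixels while keeping the input aspect ratio.

    mod_value=16 is an assumption (VAE spatial stride 8 x patch size 2); both sides
    are rounded down to a multiple of it.
    """
    aspect_ratio = img_height / img_width
    height = round(math.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
    width = round(math.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
    return width, height


print(fit_to_area(1920, 1080))  # landscape input -> (1280, 720)
print(fit_to_area(1080, 1920))  # portrait input  -> (720, 1280)
print(fit_to_area(1024, 1024))  # square input    -> (960, 960)
```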
- Multi-GPU inference using FSDP + xDiT USP
```sh
pip install "xfuser>=0.4.1"
torchrun --nproc_per_node=8 generate.py --task i2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-I2V-14B-720P --image examples/i2v_input.JPG --dit_fsdp --t5_fsdp --ulysses_size 8 --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
```
##### (2) Using Prompt Extension
The prompt extension process is described [here](#2-using-prompt-extention).
Run with local prompt extension using `Qwen/Qwen2.5-VL-7B-Instruct`:
```sh
python generate.py --task i2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-I2V-14B-720P --image examples/i2v_input.JPG --use_prompt_extend --prompt_extend_model Qwen/Qwen2.5-VL-7B-Instruct --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
```
Run with remote prompt extension using `dashscope`:
```sh
DASH_API_KEY=your_key python generate.py --task i2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-I2V-14B-720P --image examples/i2v_input.JPG --use_prompt_extend --prompt_extend_method 'dashscope' --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
```
##### (3) Running with Diffusers
You can easily run **Wan2.1**-I2V inference with Diffusers using the following code:
``` python
import torch
import numpy as np
from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
from diffusers.utils import export_to_video, load_image
from transformers import CLIPVisionModel
# Available models: Wan-AI/Wan2.1-I2V-14B-480P-Diffusers, Wan-AI/Wan2.1-I2V-14B-720P-Diffusers
model_id = "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers"
image_encoder = CLIPVisionModel.from_pretrained(model_id, subfolder="image_encoder", torch_dtype=torch.float32)
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanImageToVideoPipeline.from_pretrained(model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16)
pipe.to("cuda")
image = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
)
max_area = 720 * 1280
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
image = image.resize((width, height))
prompt = (
    "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in "
    "the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
)
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
output = pipe(
    image=image,
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=height, width=width,
    num_frames=81,
    guidance_scale=5.0
).frames[0]
export_to_video(output, "output.mp4", fps=16)
```
> 💡Note: This example does not include prompt extension or distributed inference. We will soon update it with integrated prompt extension and a multi-GPU version for Diffusers.
##### (4) Running local gradio
```sh
cd gradio
# if one only uses 480P model in gradio
DASH_API_KEY=your_key python i2v_14B_singleGPU.py --prompt_extend_method 'dashscope' --ckpt_dir_480p ./Wan2.1-I2V-14B-480P
# if one only uses 720P model in gradio
DASH_API_KEY=your_key python i2v_14B_singleGPU.py --prompt_extend_method 'dashscope' --ckpt_dir_720p ./Wan2.1-I2V-14B-720P
# if one uses both 480P and 720P models in gradio
DASH_API_KEY=your_key python i2v_14B_singleGPU.py --prompt_extend_method 'dashscope' --ckpt_dir_480p ./Wan2.1-I2V-14B-480P --ckpt_dir_720p ./Wan2.1-I2V-14B-720P
```
#### Run First-Last-Frame-to-Video Generation
First-Last-Frame-to-Video is also divided into processes with and without the prompt extension step. Currently, only 720P is supported. The specific parameters and corresponding settings are as follows:
<table>
<thead>
<tr>
<th rowspan="2">Task</th>
<th colspan="2">Resolution</th>
<th rowspan="2">Model</th>
</tr>
<tr>
<th>480P</th>
<th>720P</th>
</tr>
</thead>
<tbody>
<tr>
<td>flf2v-14B</td>
<td style="color: red;">❌</td>
<td style="color: green;">✔️</td>
<td>Wan2.1-FLF2V-14B-720P</td>
</tr>
</tbody>
</table>
##### (1) Without Prompt Extension
- Single-GPU inference
```sh
python generate.py --task flf2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-FLF2V-14B-720P --first_frame examples/flf2v_input_first_frame.png --last_frame examples/flf2v_input_last_frame.png --prompt "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird’s feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective."
```
> 💡Similar to Image-to-Video, the `size` parameter represents the area of the generated video, with the aspect ratio following that of the original input image.
- Multi-GPU inference using FSDP + xDiT USP
```sh
pip install "xfuser>=0.4.1"
torchrun --nproc_per_node=8 generate.py --task flf2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-FLF2V-14B-720P --first_frame examples/flf2v_input_first_frame.png --last_frame examples/flf2v_input_last_frame.png --dit_fsdp --t5_fsdp --ulysses_size 8 --prompt "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird’s feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective."
```
##### (2) Using Prompt Extension
The prompt extension process is described [here](#2-using-prompt-extention).
Run with local prompt extension using `Qwen/Qwen2.5-VL-7B-Instruct`:
```sh
python generate.py --task flf2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-FLF2V-14B-720P --first_frame examples/flf2v_input_first_frame.png --last_frame examples/flf2v_input_last_frame.png --use_prompt_extend --prompt_extend_model Qwen/Qwen2.5-VL-7B-Instruct --prompt "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird’s feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective."
```
Run with remote prompt extension using `dashscope`:
```sh
DASH_API_KEY=your_key python generate.py --task flf2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-FLF2V-14B-720P --first_frame examples/flf2v_input_first_frame.png --last_frame examples/flf2v_input_last_frame.png --use_prompt_extend --prompt_extend_method 'dashscope' --prompt "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird’s feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective."
```
##### (3) Running local gradio
```sh
cd gradio
# use 720P model in gradio
DASH_API_KEY=your_key python flf2v_14B_singleGPU.py --prompt_extend_method 'dashscope' --ckpt_dir_720p ./Wan2.1-FLF2V-14B-720P
```
#### Run VACE
[VACE](https://github.com/ali-vilab/VACE) now supports two models (1.3B and 14B) and two main resolutions (480P and 720P).
The input may have any resolution, but for best results the video size should be close to the approximate target sizes listed for each resolution.
The parameters and configurations for these models are as follows:
<table>
<thead>
<tr>
<th rowspan="2">Task</th>
<th colspan="2">Resolution</th>
<th rowspan="2">Model</th>
</tr>
<tr>
<th>480P(~81x480x832)</th>
<th>720P(~81x720x1280)</th>
</tr>
</thead>
<tbody>
<tr>
<td>VACE</td>
<td style="color: green; text-align: center; vertical-align: middle;">✔️</td>
<td style="color: green; text-align: center; vertical-align: middle;">✔️</td>
<td>Wan2.1-VACE-14B</td>
</tr>
<tr>
<td>VACE</td>
<td style="color: green; text-align: center; vertical-align: middle;">✔️</td>
<td style="color: red; text-align: center; vertical-align: middle;">❌</td>
<td>Wan2.1-VACE-1.3B</td>
</tr>
</tbody>
</table>
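The table gives approximate target sizes as frames x height x width. As an illustration only, a hypothetical helper that rescales an arbitrary input to roughly the target pixel area while keeping its aspect ratio might look like this; the function name and the round-to-16 rule are assumptions, not part of the VACE tooling.

```python
import math

# Approximate (frames, height, width) targets from the table above.
VACE_TARGETS = {
    "480P": (81, 480, 832),
    "720P": (81, 720, 1280),
}


def suggest_size(height, width, resolution="480P", multiple=16):
    """Rescale (height, width) to about the target pixel area, keeping the aspect ratio."""
    _, target_h, target_w = VACE_TARGETS[resolution]
    scale = math.sqrt((target_h * target_w) / (height * width))
    # Round each side to the nearest multiple (assumed alignment), never below one multiple.
    new_h = max(multiple, round(height * scale / multiple) * multiple)
    new_w = max(multiple, round(width * scale / multiple) * multiple)
    return new_h, new_w


print(suggest_size(1080, 1920, "720P"))  # (720, 1280)
print(suggest_size(480, 832, "480P"))    # (480, 832)
```

The sketch only adjusts the spatial size; the ~81-frame length from the table would be handled separately.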
In VACE, users can provide a text prompt together with an optional video, mask, and reference images for video generation or editing. Detailed instructions for using VACE can be found in the [User Guide](https://github.com/ali-vilab/VACE/blob/main/UserGuide.md).
The execution process is as follows:
##### (1) Preprocessing
User-collected materials need to be preprocessed into VACE-recognizable inputs, including `src_video`, `src_mask`, `src_ref_images`, and `prompt`.
For R2V (Reference-to-Video Generation), you may skip this preprocessing, but V2V (Video-to-Video Editing) and MV2V (Masked Video-to-Video Editing) tasks require additional preprocessing to obtain videos with conditions such as depth, pose, or masked regions.
For more details, please refer to [vace_preproccess](https://github.com/ali-vilab/VACE/blob/main/vace/vace_preproccess.py).
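Once preprocessing has produced the inputs, they map onto the `generate.py` flags `--src_video`, `--src_mask`, `--src_ref_images` (one comma-separated string), and `--prompt`. A minimal sketch of assembling that command in Python follows. `build_vace_args` is hypothetical, and the rule that a mask requires a source video is an assumption based on MV2V being a masked *video* editing task.

```python
import shlex


def build_vace_args(prompt, ckpt_dir, task="vace-1.3B", size="832*480",
                    src_video=None, src_mask=None, src_ref_images=None):
    """Assemble a generate.py argument list for VACE from preprocessed inputs."""
    if src_mask is not None and src_video is None:
        # Assumption: a mask only makes sense alongside the video it masks.
        raise ValueError("src_mask requires src_video")
    args = ["python", "generate.py", "--task", task, "--size", size,
            "--ckpt_dir", ckpt_dir]
    if src_video is not None:
        args += ["--src_video", src_video]
    if src_mask is not None:
        args += ["--src_mask", src_mask]
    if src_ref_images:
        # generate.py expects all reference images in one comma-separated string.
        args += ["--src_ref_images", ",".join(src_ref_images)]
    args += ["--prompt", prompt]
    return args


cmd = build_vace_args(
    "A little girl in a red outfit plays with a cartoon snake.",
    "./Wan2.1-VACE-1.3B",
    src_ref_images=["examples/girl.png", "examples/snake.png"],
)
print(shlex.join(cmd))  # shell-quoted command, e.g. for logging or copy-paste
```

Passing the list directly to `subprocess.run(cmd, check=True)` instead of building a shell string avoids quoting problems with long prompts.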
##### (2) CLI inference
- Single-GPU inference
```sh
python generate.py --task vace-1.3B --size 832*480 --ckpt_dir ./Wan2.1-VACE-1.3B --src_ref_images examples/girl.png,examples/snake.png --prompt "在一个欢乐而充满节日气氛的场景中,穿着鲜艳红色春服的小女孩正与她的可爱卡通蛇嬉戏。她的春服上绣着金色吉祥图案,散发着喜庆的气息,脸上洋溢着灿烂的笑容。蛇身呈现出亮眼的绿色,形状圆润,宽大的眼睛让它显得既友善又幽默。小女孩欢快地用手轻轻抚摸着蛇的头部,共同享受着这温馨的时刻。周围五彩斑斓的灯笼和彩带装饰着环境,阳光透过洒在她们身上,营造出一个充满友爱与幸福的新年氛围。"
```
- Multi-GPU inference using FSDP + xDiT USP
```sh
torchrun --nproc_per_node=8 generate.py --task vace-14B --size 1280*720 --ckpt_dir ./Wan2.1-VACE-14B --dit_fsdp --t5_fsdp --ulysses_size 8 --src_ref_images examples/girl.png,examples/snake.png --prompt "在一个欢乐而充满节日气氛的场景中,穿着鲜艳红色春服的小女孩正与她的可爱卡通蛇嬉戏。她的春服上绣着金色吉祥图案,散发着喜庆的气息,脸上洋溢着灿烂的笑容。蛇身呈现出亮眼的绿色,形状圆润,宽大的眼睛让它显得既友善又幽默。小女孩欢快地用手轻轻抚摸着蛇的头部,共同享受着这温馨的时刻。周围五彩斑斓的灯笼和彩带装饰着环境,阳光透过洒在她们身上,营造出一个充满友爱与幸福的新年氛围。"
```
##### (3) Running local gradio
- Single-GPU inference
```sh
python gradio/vace.py --ckpt_dir ./Wan2.1-VACE-1.3B
```
- Multi-GPU inference using FSDP + xDiT USP
```sh
python gradio/vace.py --mp --ulysses_size 8 --ckpt_dir ./Wan2.1-VACE-14B/
```
#### Run Text-to-Image Generation
Wan2.1 is a unified model for both image and video generation. Since it was trained on both types of data, it can also generate images. The command for generating images is similar to video generation, as follows:
##### (1) Without Prompt Extension
- Single-GPU inference
```sh
python generate.py --task t2i-14B --size 1024*1024 --ckpt_dir ./Wan2.1-T2V-14B --prompt '一个朴素端庄的美人'
```
- Multi-GPU inference using FSDP + xDiT USP
```sh
torchrun --nproc_per_node=8 generate.py --dit_fsdp --t5_fsdp --ulysses_size 8 --base_seed 0 --frame_num 1 --task t2i-14B --size 1024*1024 --prompt '一个朴素端庄的美人' --ckpt_dir ./Wan2.1-T2V-14B
```
##### (2) With Prompt Extension
- Single-GPU inference
```sh
python generate.py --task t2i-14B --size 1024*1024 --ckpt_dir ./Wan2.1-T2V-14B --prompt '一个朴素端庄的美人' --use_prompt_extend
```
- Multi-GPU inference using FSDP + xDiT USP
```sh
torchrun --nproc_per_node=8 generate.py --dit_fsdp --t5_fsdp --ulysses_size 8 --base_seed 0 --frame_num 1 --task t2i-14B --size 1024*1024 --ckpt_dir ./Wan2.1-T2V-14B --prompt '一个朴素端庄的美人' --use_prompt_extend
```
## Manual Evaluation
##### (1) Text-to-Video Evaluation
In our manual evaluation, results generated with prompt extension outperform those from both closed-source and open-source models.
<div align="center">
<img src="assets/t2v_res.jpg" alt="" style="width: 80%;" />
</div>
##### (2) Image-to-Video Evaluation
We also conducted extensive manual evaluations of the Image-to-Video model, and the results are presented in the table below. They clearly indicate that **Wan2.1** outperforms both closed-source and open-source models.
<div align="center">
<img src="assets/i2v_res.png" alt="" style="width: 80%;" />
</div>
## Computational Efficiency on Different GPUs
The table below reports the computational efficiency of different **Wan2.1** models on different GPUs. The results are presented in the format: **Total time (s) / peak GPU memory (GB)**.
<div align="center">
<img src="assets/comp_effic.png" alt="" style="width: 80%;" />
</div>
> The parameter settings for the tests presented in this table are as follows:
> 1. For the 1.3B model on 8 GPUs, set `--ring_size 8` and `--ulysses_size 1`;
> 2. For the 14B model on 1 GPU, use `--offload_model True`;
> 3. For the 1.3B model on a single 4090 GPU, set `--offload_model True --t5_cpu`;
> 4. For all tests, no prompt extension was applied, meaning `--use_prompt_extend` was not enabled.
>
> 💡Note: T2V-14B is slower than I2V-14B because the former samples 50 steps while the latter uses 40 steps.
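The four benchmark settings can be captured in a small lookup, for example when scripting benchmark runs. This is a hypothetical helper that only encodes the notes in this section. For multi-GPU runs the flags would still be passed to `generate.py` through `torchrun`, as in the commands earlier in this README.

```python
def benchmark_flags(model, num_gpus, gpu=None):
    """Extra generate.py flags matching the benchmark notes above (illustrative only)."""
    if model == "1.3B" and num_gpus == 8:
        return ["--ring_size", "8", "--ulysses_size", "1"]
    if model == "1.3B" and num_gpus == 1 and gpu == "4090":
        return ["--offload_model", "True", "--t5_cpu"]
    if model == "14B" and num_gpus == 1:
        return ["--offload_model", "True"]
    # Other configurations are not covered by the notes. Prompt extension is
    # never enabled, so --use_prompt_extend is intentionally absent.
    return []


print(" ".join(benchmark_flags("1.3B", 1, gpu="4090")))  # --offload_model True --t5_cpu
```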
-------
## Introduction of Wan2.1
**Wan2.1** is designed on the mainstream diffusion transformer paradigm, achieving significant advancements in generative capabilities through a series of innovations. These include our novel spatio-temporal variational autoencoder (VAE), scalable training strategies, large-scale data construction, and automated evaluation metrics. Collectively, these contributions enhance the model’s performance and versatility.
##### (1) 3D Variational Autoencoders
We propose a novel 3D causal VAE architecture, termed **Wan-VAE**, specifically designed for video generation. By combining multiple strategies, we improve spatio-temporal compression, reduce memory usage, and ensure temporal causality. **Wan-VAE** demonstrates significant advantages in performance efficiency compared to other open-source VAEs. Furthermore, our **Wan-VAE** can encode and decode unlimited-length 1080P videos without losing historical temporal information, making it particularly well-suited for video generation tasks.
<div align="center">
<img src="assets/video_vae_res.jpg" alt="" style="width: 80%;" />
</div>
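To make the compression concrete, the sketch below computes latent and token counts for a clip. It assumes a temporal stride of 4 with the first frame encoded on its own (which would explain why `generate.py` asks for frame counts of the form 4n+1), a spatial stride of 8, 16 latent channels (the DiT input dimension in the table below), and a DiT patch size of (1, 2, 2). These strides and the patch size are assumptions, not figures stated in this section.

```python
def latent_shape(num_frames, height, width, z_dim=16, stride=(4, 8, 8)):
    """Latent (channels, frames, height, width) under the assumed VAE strides."""
    t_stride, h_stride, w_stride = stride
    if (num_frames - 1) % t_stride != 0:
        raise ValueError(f"num_frames must be {t_stride}n+1, got {num_frames}")
    # The first frame is encoded on its own; every further t_stride frames become one latent frame.
    latent_frames = (num_frames - 1) // t_stride + 1
    return (z_dim, latent_frames, height // h_stride, width // w_stride)


def num_tokens(latent, patch=(1, 2, 2)):
    """Number of DiT tokens after splitting the latent into patches of the assumed size."""
    _, t, h, w = latent
    return (t // patch[0]) * (h // patch[1]) * (w // patch[2])


shape = latent_shape(81, 720, 1280)
print(shape)              # (16, 21, 90, 160)
print(num_tokens(shape))  # 75600
```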
##### (2) Video Diffusion DiT
**Wan2.1** is designed using the Flow Matching framework within the paradigm of mainstream Diffusion Transformers. Our model's architecture uses the T5 Encoder to encode multilingual text input, with cross-attention in each transformer block embedding the text into the model structure. Additionally, we employ an MLP with a Linear layer and a SiLU layer to process the input time embeddings and predict six modulation parameters individually. This MLP is shared across all transformer blocks, with each block learning a distinct set of biases. Our experimental findings reveal a significant performance improvement with this approach at the same parameter scale.
<div align="center">
<img src="assets/video_dit_arch.jpg" alt="" style="width: 80%;" />
</div>
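The shared time-embedding MLP with per-block biases can be sketched in plain Python as below. This is a minimal illustration under assumptions, not the repository's implementation: the weight initialization, the tiny `dim`, and reading the six outputs as shift/scale/gate triples for the attention and feed-forward sublayers are all assumed.

```python
import math
import random


def silu(x):
    return x / (1.0 + math.exp(-x))


class SharedTimeMLP:
    """SiLU followed by one Linear layer mapping dim -> 6 * dim, shared by every block."""

    def __init__(self, dim, rng):
        self.weight = [[rng.gauss(0.0, 0.02) for _ in range(dim)] for _ in range(6 * dim)]
        self.bias = [0.0] * (6 * dim)

    def __call__(self, t_emb):
        h = [silu(v) for v in t_emb]
        return [sum(w * x for w, x in zip(row, h)) + b
                for row, b in zip(self.weight, self.bias)]


class Block:
    """Stand-in for a transformer block; it only owns its learned 6 * dim modulation bias."""

    def __init__(self, dim, rng):
        self.dim = dim
        self.modulation = [rng.gauss(0.0, dim ** -0.5) for _ in range(6 * dim)]

    def modulation_params(self, shared_out):
        mixed = [a + b for a, b in zip(shared_out, self.modulation)]
        # Six dim-sized chunks, e.g. (shift, scale, gate) for attention and for the FFN.
        return [mixed[i * self.dim:(i + 1) * self.dim] for i in range(6)]


rng = random.Random(0)
dim = 4
mlp = SharedTimeMLP(dim, rng)
blocks = [Block(dim, rng) for _ in range(3)]
shared = mlp([0.1, -0.2, 0.3, 0.4])  # computed once per timestep for all blocks
params = [block.modulation_params(shared) for block in blocks]
print(len(params), len(params[0]), len(params[0][0]))  # 3 6 4
```

Only the 6 * dim bias is stored per block, so the projection's 6 * dim * dim weights exist once rather than once per layer.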
| Model | Dimension | Input Dimension | Output Dimension | Feedforward Dimension | Frequency Dimension | Number of Heads | Number of Layers |
|--------|-----------|-----------------|------------------|-----------------------|---------------------|-----------------|------------------|
| 1.3B | 1536 | 16 | 16 | 8960 | 256 | 12 | 30 |
| 14B | 5120 | 16 | 16 | 13824 | 256 | 40 | 40 |
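The table also explains a constraint in `generate.py`, which asserts that `num_heads` is divisible by `--ulysses_size`. The sketch below uses only the dimensions from the table (the helper names are hypothetical) to list the Ulysses degrees each model supports on up to 8 GPUs.

```python
# (dim, num_heads, num_layers) copied from the table above.
MODEL_DIMS = {"1.3B": (1536, 12, 30), "14B": (5120, 40, 40)}


def head_dim(model):
    """Per-head width: model dimension divided by the number of heads."""
    dim, heads, _ = MODEL_DIMS[model]
    return dim // heads


def valid_ulysses_sizes(model, max_gpus=8):
    """Ulysses degrees that split the heads evenly, mirroring generate.py's assertion."""
    _, heads, _ = MODEL_DIMS[model]
    return [n for n in range(1, max_gpus + 1) if heads % n == 0]


for name in MODEL_DIMS:
    print(name, head_dim(name), valid_ulysses_sizes(name))
# 1.3B 128 [1, 2, 3, 4, 6]
# 14B 128 [1, 2, 4, 5, 8]
```

Because 12 heads cannot be split 8 ways, this is consistent with the benchmark setting of `--ring_size 8 --ulysses_size 1` for the 1.3B model on 8 GPUs.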
##### Data
We curated and deduplicated a candidate dataset comprising a vast amount of image and video data. During the data curation process, we designed a four-step data cleaning process, focusing on fundamental dimensions, visual quality and motion quality. Through the robust data processing pipeline, we can easily obtain high-quality, diverse, and large-scale training sets of images and videos.

##### Comparisons to SOTA
We compared **Wan2.1** with leading open-source and closed-source models to evaluate its performance. Using our carefully designed set of 1,035 internal prompts, we tested across 14 major dimensions and 26 sub-dimensions. We then computed the total score as a weighted combination of the per-dimension scores, using weights derived from human preferences in the matching process. The detailed results are shown in the table below. These results demonstrate our model's superior performance compared to both open-source and closed-source models.
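The README does not list the actual weights, but the arithmetic of a weighted total is simple. A minimal sketch, assuming a normalized weighted average; the dimension names, scores, and weights below are made up for illustration and are not Wan2.1 results.

```python
def weighted_total(scores, weights):
    """Normalized weighted average of per-dimension scores (weights need not sum to 1)."""
    mismatch = set(scores) ^ set(weights)
    if mismatch:
        raise ValueError(f"scores and weights cover different dimensions: {sorted(mismatch)}")
    total_weight = sum(weights.values())
    return sum(scores[k] * weights[k] for k in scores) / total_weight


# Hypothetical dimensions, scores, and weights, for illustration only.
scores = {"motion quality": 0.80, "visual quality": 0.70, "prompt following": 0.60}
weights = {"motion quality": 2.0, "visual quality": 1.0, "prompt following": 1.0}
print(round(weighted_total(scores, weights), 3))  # 0.725
```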

## Citation
If you find our work helpful, please cite us.
```bibtex
@article{wan2025,
title={Wan: Open and Advanced Large-Scale Video Generative Models},
author={Team Wan and Ang Wang and Baole Ai and Bin Wen and Chaojie Mao and Chen-Wei Xie and Di Chen and Feiwu Yu and Haiming Zhao and Jianxiao Yang and Jianyuan Zeng and Jiayu Wang and Jingfeng Zhang and Jingren Zhou and Jinkai Wang and Jixuan Chen and Kai Zhu and Kang Zhao and Keyu Yan and Lianghua Huang and Mengyang Feng and Ningyi Zhang and Pandeng Li and Pingyu Wu and Ruihang Chu and Ruili Feng and Shiwei Zhang and Siyang Sun and Tao Fang and Tianxing Wang and Tianyi Gui and Tingyu Weng and Tong Shen and Wei Lin and Wei Wang and Wei Wang and Wenmeng Zhou and Wente Wang and Wenting Shen and Wenyuan Yu and Xianzhong Shi and Xiaoming Huang and Xin Xu and Yan Kou and Yangyu Lv and Yifei Li and Yijing Liu and Yiming Wang and Yingya Zhang and Yitong Huang and Yong Li and You Wu and Yu Liu and Yulin Pan and Yun Zheng and Yuntao Hong and Yupeng Shi and Yutong Feng and Zeyinzi Jiang and Zhen Han and Zhi-Fan Wu and Ziyu Liu},
journal = {arXiv preprint arXiv:2503.20314},
year={2025}
}
```
## License Agreement
The models in this repository are licensed under the Apache 2.0 License. We claim no rights over your generated content, granting you the freedom to use it provided that your usage complies with the provisions of this license. You are fully accountable for your use of the models, which must not involve sharing any content that violates applicable laws, causes harm to individuals or groups, disseminates personal information intended for harm, spreads misinformation, or targets vulnerable populations. For a complete list of restrictions and details regarding your rights, please refer to the full text of the [license](LICENSE.txt).
## Acknowledgements
We would like to thank the contributors to the [SD3](https://huggingface.co/stabilityai/stable-diffusion-3-medium), [Qwen](https://huggingface.co/Qwen), [umt5-xxl](https://huggingface.co/google/umt5-xxl), [diffusers](https://github.com/huggingface/diffusers) and [HuggingFace](https://huggingface.co) repositories, for their open research.
## Contact Us
If you would like to leave a message for our research or product teams, feel free to join our [Discord](https://discord.gg/AKNgpMK4Yj) or [WeChat groups](https://gw.alicdn.com/imgextra/i2/O1CN01tqjWFi1ByuyehkTSB_!!6000000000015-0-tps-611-1279.jpg)!
================================================
FILE: generate.py
================================================
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
import logging
import os
import sys
import warnings
from datetime import datetime
warnings.filterwarnings('ignore')
import random
import torch
import torch.distributed as dist
from PIL import Image
import wan
from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
from wan.utils.utils import cache_image, cache_video, str2bool
EXAMPLE_PROMPT = {
"t2v-1.3B": {
"prompt":
"Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
},
"t2v-14B": {
"prompt":
"Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
},
"t2i-14B": {
"prompt": "一个朴素端庄的美人",
},
"i2v-14B": {
"prompt":
"Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.",
"image":
"examples/i2v_input.JPG",
},
"flf2v-14B": {
"prompt":
"CG动画风格,一只蓝色的小鸟从地面起飞,煽动翅膀。小鸟羽毛细腻,胸前有独特的花纹,背景是蓝天白云,阳光明媚。镜跟随小鸟向上移动,展现出小鸟飞翔的姿态和天空的广阔。近景,仰视视角。",
"first_frame":
"examples/flf2v_input_first_frame.png",
"last_frame":
"examples/flf2v_input_last_frame.png",
},
"vace-1.3B": {
"src_ref_images":
'examples/girl.png,examples/snake.png',
"prompt":
"在一个欢乐而充满节日气氛的场景中,穿着鲜艳红色春服的小女孩正与她的可爱卡通蛇嬉戏。她的春服上绣着金色吉祥图案,散发着喜庆的气息,脸上洋溢着灿烂的笑容。蛇身呈现出亮眼的绿色,形状圆润,宽大的眼睛让它显得既友善又幽默。小女孩欢快地用手轻轻抚摸着蛇的头部,共同享受着这温馨的时刻。周围五彩斑斓的灯笼和彩带装饰着环境,阳光透过洒在她们身上,营造出一个充满友爱与幸福的新年氛围。"
},
"vace-14B": {
"src_ref_images":
'examples/girl.png,examples/snake.png',
"prompt":
"在一个欢乐而充满节日气氛的场景中,穿着鲜艳红色春服的小女孩正与她的可爱卡通蛇嬉戏。她的春服上绣着金色吉祥图案,散发着喜庆的气息,脸上洋溢着灿烂的笑容。蛇身呈现出亮眼的绿色,形状圆润,宽大的眼睛让它显得既友善又幽默。小女孩欢快地用手轻轻抚摸着蛇的头部,共同享受着这温馨的时刻。周围五彩斑斓的灯笼和彩带装饰着环境,阳光透过洒在她们身上,营造出一个充满友爱与幸福的新年氛围。"
}
}
def _validate_args(args):
    # Basic check
    assert args.ckpt_dir is not None, "Please specify the checkpoint directory."
    assert args.task in WAN_CONFIGS, f"Unsupported task: {args.task}"
    assert args.task in EXAMPLE_PROMPT, f"Unsupported task: {args.task}"

    # The default sampling steps are 40 for image-to-video tasks and 50 for text-to-video tasks.
    if args.sample_steps is None:
        args.sample_steps = 50
        if "i2v" in args.task:
            args.sample_steps = 40

    if args.sample_shift is None:
        args.sample_shift = 5.0
        if "i2v" in args.task and args.size in ["832*480", "480*832"]:
            args.sample_shift = 3.0
        elif "flf2v" in args.task or "vace" in args.task:
            args.sample_shift = 16

    # The default number of frames is 1 for text-to-image tasks and 81 for other tasks.
    if args.frame_num is None:
        args.frame_num = 1 if "t2i" in args.task else 81

    # T2I frame_num check
    if "t2i" in args.task:
        assert args.frame_num == 1, f"Unsupported frame_num {args.frame_num} for task {args.task}"

    args.base_seed = args.base_seed if args.base_seed >= 0 else random.randint(
        0, sys.maxsize)
    # Size check
    assert args.size in SUPPORTED_SIZES[args.task], (
        f"Unsupported size {args.size} for task {args.task}, "
        f"supported sizes are: {', '.join(SUPPORTED_SIZES[args.task])}")
def _parse_args():
    parser = argparse.ArgumentParser(
        description="Generate an image or video from a text prompt or image using Wan"
    )
    parser.add_argument(
        "--task",
        type=str,
        default="t2v-14B",
        choices=list(WAN_CONFIGS.keys()),
        help="The task to run.")
    parser.add_argument(
        "--size",
        type=str,
        default="1280*720",
        choices=list(SIZE_CONFIGS.keys()),
        help="The area (width*height) of the generated video. For the I2V task, the aspect ratio of the output video will follow that of the input image."
    )
    parser.add_argument(
        "--frame_num",
        type=int,
        default=None,
        help="How many frames to sample for the image or video. The number should be 4n+1."
    )
    parser.add_argument(
        "--ckpt_dir",
        type=str,
        default=None,
        help="The path to the checkpoint directory.")
    parser.add_argument(
        "--offload_model",
        type=str2bool,
        default=None,
        help="Whether to offload the model to CPU after each model forward, reducing GPU memory usage."
    )
    parser.add_argument(
        "--ulysses_size",
        type=int,
        default=1,
        help="The size of the ulysses parallelism in DiT.")
    parser.add_argument(
        "--ring_size",
        type=int,
        default=1,
        help="The size of the ring attention parallelism in DiT.")
    parser.add_argument(
        "--t5_fsdp",
        action="store_true",
        default=False,
        help="Whether to use FSDP for T5.")
    parser.add_argument(
        "--t5_cpu",
        action="store_true",
        default=False,
        help="Whether to place the T5 model on CPU.")
    parser.add_argument(
        "--dit_fsdp",
        action="store_true",
        default=False,
        help="Whether to use FSDP for DiT.")
    parser.add_argument(
        "--save_file",
        type=str,
        default=None,
        help="The file to save the generated image or video to.")
    parser.add_argument(
        "--src_video",
        type=str,
        default=None,
        help="The file of the source video. Default None.")
    parser.add_argument(
        "--src_mask",
        type=str,
        default=None,
        help="The file of the source mask. Default None.")
    parser.add_argument(
        "--src_ref_images",
        type=str,
        default=None,
        help="The file list of the source reference images. Separated by ','. Default None."
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default=None,
        help="The prompt to generate the image or video from.")
    parser.add_argument(
        "--use_prompt_extend",
        action="store_true",
        default=False,
        help="Whether to use prompt extension.")
    parser.add_argument(
        "--prompt_extend_method",
        type=str,
        default="local_qwen",
        choices=["dashscope", "local_qwen"],
        help="The prompt extension method to use.")
    parser.add_argument(
        "--prompt_extend_model",
        type=str,
        default=None,
        help="The prompt extension model to use.")
    parser.add_argument(
        "--prompt_extend_target_lang",
        type=str,
        default="zh",
        choices=["zh", "en"],
        help="The target language of prompt extension.")
    parser.add_argument(
        "--base_seed",
        type=int,
        default=-1,
        help="The seed to use for generating the image or video.")
    parser.add_argument(
        "--image",
        type=str,
        default=None,
        help="[image to video] The image to generate the video from.")
    parser.add_argument(
        "--first_frame",
        type=str,
        default=None,
        help="[first-last frame to video] The image (first frame) to generate the video from."
    )
    parser.add_argument(
        "--last_frame",
        type=str,
        default=None,
        help="[first-last frame to video] The image (last frame) to generate the video from."
    )
    parser.add_argument(
        "--sample_solver",
        type=str,
        default='unipc',
        choices=['unipc', 'dpm++'],
        help="The solver used to sample.")
    parser.add_argument(
        "--sample_steps", type=int, default=None, help="The sampling steps.")
    parser.add_argument(
        "--sample_shift",
        type=float,
        default=None,
        help="Sampling shift factor for flow matching schedulers.")
    parser.add_argument(
        "--sample_guide_scale",
        type=float,
        default=5.0,
        help="Classifier free guidance scale.")

    args = parser.parse_args()

    _validate_args(args)

    return args
def _init_logging(rank):
    # logging
    if rank == 0:
        # set format
        logging.basicConfig(
            level=logging.INFO,
            format="[%(asctime)s] %(levelname)s: %(message)s",
            handlers=[logging.StreamHandler(stream=sys.stdout)])
    else:
        logging.basicConfig(level=logging.ERROR)
def generate(args):
    rank = int(os.getenv("RANK", 0))
    world_size = int(os.getenv("WORLD_SIZE", 1))
    local_rank = int(os.getenv("LOCAL_RANK", 0))
    device = local_rank
    _init_logging(rank)

    if args.offload_model is None:
        args.offload_model = False if world_size > 1 else True
        logging.info(
            f"offload_model is not specified, set to {args.offload_model}.")
    if world_size > 1:
        torch.cuda.set_device(local_rank)
        dist.init_process_group(
            backend="nccl",
            init_method="env://",
            rank=rank,
            world_size=world_size)
    else:
        assert not (
            args.t5_fsdp or args.dit_fsdp
        ), "t5_fsdp and dit_fsdp are not supported in non-distributed environments."
        assert not (
            args.ulysses_size > 1 or args.ring_size > 1
        ), "Context parallelism is not supported in non-distributed environments."

    if args.ulysses_size > 1 or args.ring_size > 1:
        assert args.ulysses_size * args.ring_size == world_size, "The product of ulysses_size and ring_size must equal the world size."
        from xfuser.core.distributed import (
            init_distributed_environment,
            initialize_model_parallel,
        )
        init_distributed_environment(
            rank=dist.get_rank(), world_size=dist.get_world_size())

        initialize_model_parallel(
            sequence_parallel_degree=dist.get_world_size(),
            ring_degree=args.ring_size,
            ulysses_degree=args.ulysses_size,
        )
    if args.use_prompt_extend:
        if args.prompt_extend_method == "dashscope":
            prompt_expander = DashScopePromptExpander(
                model_name=args.prompt_extend_model,
                is_vl="i2v" in args.task or "flf2v" in args.task)
        elif args.prompt_extend_method == "local_qwen":
            prompt_expander = QwenPromptExpander(
                model_name=args.prompt_extend_model,
                is_vl="i2v" in args.task,
                device=rank)
        else:
            raise NotImplementedError(
                f"Unsupported prompt_extend_method: {args.prompt_extend_method}")
    cfg = WAN_CONFIGS[args.task]
    if args.ulysses_size > 1:
        assert cfg.num_heads % args.ulysses_size == 0, f"`{cfg.num_heads=}` cannot be divided evenly by `{args.ulysses_size=}`."

    logging.info(f"Generation job args: {args}")
    logging.info(f"Generation model config: {cfg}")

    if dist.is_initialized():
        base_seed = [args.base_seed] if rank == 0 else [None]
        dist.broadcast_object_list(base_seed, src=0)
        args.base_seed = base_seed[0]
if "t2v" in args.task or "t2i" in args.task:
if args.prompt is None:
args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
logging.info(f"Input prompt: {args.prompt}")
if args.use_prompt_extend:
logging.info("Extending prompt ...")
if rank == 0:
prompt_output = prompt_expander(
args.prompt,
tar_lang=args.prompt_extend_target_lang,
seed=args.base_seed)
if not prompt_output.status:
logging.info(
f"Extending prompt failed: {prompt_output.message}")
logging.info("Falling back to original prompt.")
input_prompt = args.prompt
else:
input_prompt = prompt_output.prompt
input_prompt = [input_prompt]
else:
input_prompt = [None]
if dist.is_initialized():
dist.broadcast_object_list(input_prompt, src=0)
args.prompt = input_prompt[0]
logging.info(f"Extended prompt: {args.prompt}")
logging.info("Creating WanT2V pipeline.")
wan_t2v = wan.WanT2V(
config=cfg,
checkpoint_dir=args.ckpt_dir,
device_id=device,
rank=rank,
t5_fsdp=args.t5_fsdp,
dit_fsdp=args.dit_fsdp,
use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
t5_cpu=args.t5_cpu,
)
logging.info(
f"Generating {'image' if 't2i' in args.task else 'video'} ...")
video = wan_t2v.generate(
args.prompt,
size=SIZE_CONFIGS[args.size],
frame_num=args.frame_num,
shift=args.sample_shift,
sample_solver=args.sample_solver,
sampling_steps=args.sample_steps,
guide_scale=args.sample_guide_scale,
seed=args.base_seed,
offload_model=args.offload_model)
elif "i2v" in args.task:
if args.prompt is None:
args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
if args.image is None:
args.image = EXAMPLE_PROMPT[args.task]["image"]
logging.info(f"Input prompt: {args.prompt}")
logging.info(f"Input image: {args.image}")
img = Image.open(args.image).convert("RGB")
if args.use_prompt_extend:
logging.info("Extending prompt ...")
if rank == 0:
prompt_output = prompt_expander(
args.prompt,
tar_lang=args.prompt_extend_target_lang,
image=img,
seed=args.base_seed)
if not prompt_output.status:
logging.info(
f"Extending prompt failed: {prompt_output.message}")
logging.info("Falling back to original prompt.")
input_prompt = args.prompt
else:
input_prompt = prompt_output.prompt
input_prompt = [input_prompt]
else:
input_prompt = [None]
if dist.is_initialized():
dist.broadcast_object_list(input_prompt, src=0)
args.prompt = input_prompt[0]
logging.info(f"Extended prompt: {args.prompt}")
logging.info("Creating WanI2V pipeline.")
wan_i2v = wan.WanI2V(
config=cfg,
checkpoint_dir=args.ckpt_dir,
device_id=device,
rank=rank,
t5_fsdp=args.t5_fsdp,
dit_fsdp=args.dit_fsdp,
use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
t5_cpu=args.t5_cpu,
)
logging.info("Generating video ...")
video = wan_i2v.generate(
args.prompt,
img,
max_area=MAX_AREA_CONFIGS[args.size],
frame_num=args.frame_num,
shift=args.sample_shift,
sample_solver=args.sample_solver,
sampling_steps=args.sample_steps,
guide_scale=args.sample_guide_scale,
seed=args.base_seed,
offload_model=args.offload_model)
elif "flf2v" in args.task:
if args.prompt is None:
args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
if args.first_frame is None or args.last_frame is None:
args.first_frame = EXAMPLE_PROMPT[args.task]["first_frame"]
args.last_frame = EXAMPLE_PROMPT[args.task]["last_frame"]
logging.info(f"Input prompt: {args.prompt}")
logging.info(f"Input first frame: {args.first_frame}")
logging.info(f"Input last frame: {args.last_frame}")
first_frame = Image.open(args.first_frame).convert("RGB")
last_frame = Image.open(args.last_frame).convert("RGB")
if args.use_prompt_extend:
logging.info("Extending prompt ...")
if rank == 0:
prompt_output = prompt_expander(
args.prompt,
tar_lang=args.prompt_extend_target_lang,
image=[first_frame, last_frame],
seed=args.base_seed)
if not prompt_output.status:
logging.info(
f"Extending prompt failed: {prompt_output.message}")
logging.info("Falling back to original prompt.")
input_prompt = args.prompt
else:
input_prompt = prompt_output.prompt
input_prompt = [input_prompt]
else:
input_prompt = [None]
if dist.is_initialized():
dist.broadcast_object_list(input_prompt, src=0)
args.prompt = input_prompt[0]
logging.info(f"Extended prompt: {args.prompt}")
logging.info("Creating WanFLF2V pipeline.")
wan_flf2v = wan.WanFLF2V(
config=cfg,
checkpoint_dir=args.ckpt_dir,
device_id=device,
rank=rank,
t5_fsdp=args.t5_fsdp,
dit_fsdp=args.dit_fsdp,
use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
t5_cpu=args.t5_cpu,
)
logging.info("Generating video ...")
video = wan_flf2v.generate(
args.prompt,
first_frame,
last_frame,
max_area=MAX_AREA_CONFIGS[args.size],
frame_num=args.frame_num,
shift=args.sample_shift,
sample_solver=args.sample_solver,
sampling_steps=args.sample_steps,
guide_scale=args.sample_guide_scale,
seed=args.base_seed,
offload_model=args.offload_model)
elif "vace" in args.task:
if args.prompt is None:
args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
args.src_video = EXAMPLE_PROMPT[args.task].get("src_video", None)
args.src_mask = EXAMPLE_PROMPT[args.task].get("src_mask", None)
args.src_ref_images = EXAMPLE_PROMPT[args.task].get(
"src_ref_images", None)
logging.info(f"Input prompt: {args.prompt}")
if args.use_prompt_extend and args.use_prompt_extend != 'plain':
logging.info("Extending prompt ...")
if rank == 0:
prompt = prompt_expander.forward(args.prompt)
logging.info(
f"Prompt extended from '{args.prompt}' to '{prompt}'")
input_prompt = [prompt]
else:
input_prompt = [None]
if dist.is_initialized():
dist.broadcast_object_list(input_prompt, src=0)
args.prompt = input_prompt[0]
logging.info(f"Extended prompt: {args.prompt}")
logging.info("Creating VACE pipeline.")
wan_vace = wan.WanVace(
config=cfg,
checkpoint_dir=args.ckpt_dir,
device_id=device,
rank=rank,
t5_fsdp=args.t5_fsdp,
dit_fsdp=args.dit_fsdp,
use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
t5_cpu=args.t5_cpu,
)
src_video, src_mask, src_ref_images = wan_vace.prepare_source(
[args.src_video], [args.src_mask], [
None if args.src_ref_images is None else
args.src_ref_images.split(',')
], args.frame_num, SIZE_CONFIGS[args.size], device)
logging.info("Generating video ...")
video = wan_vace.generate(
args.prompt,
src_video,
src_mask,
src_ref_images,
size=SIZE_CONFIGS[args.size],
frame_num=args.frame_num,
shift=args.sample_shift,
sample_solver=args.sample_solver,
sampling_steps=args.sample_steps,
guide_scale=args.sample_guide_scale,
seed=args.base_seed,
offload_model=args.offload_model)
else:
raise ValueError(f"Unknown task type: {args.task}")
if rank == 0:
if args.save_file is None:
formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S")
formatted_prompt = args.prompt.replace(" ", "_").replace("/",
"_")[:50]
suffix = '.png' if "t2i" in args.task else '.mp4'
args.save_file = f"{args.task}_{args.size.replace('*','x') if sys.platform=='win32' else args.size}_{args.ulysses_size}_{args.ring_size}_{formatted_prompt}_{formatted_time}" + suffix
if "t2i" in args.task:
logging.info(f"Saving generated image to {args.save_file}")
cache_image(
tensor=video.squeeze(1)[None],
save_file=args.save_file,
nrow=1,
normalize=True,
value_range=(-1, 1))
else:
logging.info(f"Saving generated video to {args.save_file}")
cache_video(
tensor=video[None],
save_file=args.save_file,
fps=cfg.sample_fps,
nrow=1,
normalize=True,
value_range=(-1, 1))
logging.info("Finished.")
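# Illustrative sketch (hypothetical helper, not part of the upstream Wan2.1
# code): the distributed-setup checks at the top of generate() expressed as
# one pure function so they can be unit-tested without GPUs. It only mirrors
# the offload_model default and the context-parallel asserts shown above;
# the name `_sketch_check_parallel_layout` is an assumption.
def _sketch_check_parallel_layout(world_size, ulysses_size, ring_size,
                                  num_heads, offload_model=None):
    """Validate a context-parallel layout and return the resolved offload flag.

    Raises ValueError when the layout cannot work.
    """
    # offload_model defaults to True on a single GPU, False when distributed.
    if offload_model is None:
        offload_model = world_size == 1
    use_context_parallel = ulysses_size > 1 or ring_size > 1
    if world_size == 1 and use_context_parallel:
        raise ValueError(
            "Context parallelism is not supported in non-distributed environments.")
    # Every rank must belong to exactly one (ulysses, ring) position.
    if use_context_parallel and ulysses_size * ring_size != world_size:
        raise ValueError(
            "The product of ulysses_size and ring_size must equal the world size.")
    # Ulysses attention splits heads across ranks, so heads must divide evenly.
    if ulysses_size > 1 and num_heads % ulysses_size != 0:
        raise ValueError(
            f"num_heads={num_heads} cannot be divided evenly by "
            f"ulysses_size={ulysses_size}.")
    return offload_model
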
if __name__ == "__main__":
args = _parse_args()
generate(args)
================================================
FILE: gradio/fl2v_14B_singleGPU.py
================================================
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
import gc
import os
import os.path as osp
import sys
import warnings
import gradio as gr
warnings.filterwarnings('ignore')
# Model
sys.path.insert(
0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
import wan
from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
from wan.utils.utils import cache_video
# Global Var
prompt_expander = None
wan_flf2v_720P = None
# Button Func
def load_model(value):
global wan_flf2v_720P
if value == '------':
print("No model loaded")
return '------'
if value == '720P':
if args.ckpt_dir_720p is None:
print("Please specify the checkpoint directory for 720P model")
return '------'
if wan_flf2v_720P is not None:
pass
else:
gc.collect()
print("load 14B-720P flf2v model...", end='', flush=True)
cfg = WAN_CONFIGS['flf2v-14B']
wan_flf2v_720P = wan.WanFLF2V(
config=cfg,
checkpoint_dir=args.ckpt_dir_720p,
device_id=0,
rank=0,
t5_fsdp=False,
dit_fsdp=False,
use_usp=False,
)
print("done", flush=True)
return '720P'
return value
def prompt_enc(prompt, img_first, img_last, tar_lang):
print('prompt extend...')
if img_first is None or img_last is None:
print('Please upload the first and last frames')
return prompt
global prompt_expander
prompt_output = prompt_expander(
prompt, image=[img_first, img_last], tar_lang=tar_lang.lower())
if not prompt_output.status:
return prompt
else:
return prompt_output.prompt
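# Illustrative sketch (hypothetical names, not the wan.utils.prompt_extend
# API): prompt_enc() above keeps the user's prompt whenever the expander
# reports failure. `_SketchPromptOutput` only mimics the `status`, `prompt`
# and `message` fields that the code above reads from the expander output.
from dataclasses import dataclass


@dataclass
class _SketchPromptOutput:
    status: bool
    prompt: str
    message: str = ""


def _sketch_extend_or_fallback(expander, prompt, **kwargs):
    """Return the extended prompt, or the original prompt if extension fails."""
    output = expander(prompt, **kwargs)
    if not output.status:
        return prompt  # fall back to what the user typed
    return output.prompt
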
def flf2v_generation(flf2vid_prompt, flf2vid_image_first, flf2vid_image_last,
resolution, sd_steps, guide_scale, shift_scale, seed,
n_prompt):
if resolution == '------':
print(
'Please specify the checkpoint directory and select a resolution')
return None
else:
if resolution == '720P':
global wan_flf2v_720P
video = wan_flf2v_720P.generate(
flf2vid_prompt,
flf2vid_image_first,
flf2vid_image_last,
max_area=MAX_AREA_CONFIGS['720*1280'],
shift=shift_scale,
sampling_steps=sd_steps,
guide_scale=guide_scale,
n_prompt=n_prompt,
seed=seed,
offload_model=True)
else:
print('Sorry, currently only 720P is supported.')
return None
cache_video(
tensor=video[None],
save_file="example.mp4",
fps=16,
nrow=1,
normalize=True,
value_range=(-1, 1))
return "example.mp4"
# Interface
def gradio_interface():
with gr.Blocks() as demo:
gr.Markdown("""
<div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
Wan2.1 (FLF2V-14B)
</div>
<div style="text-align: center; font-size: 16px; font-weight: normal; margin-bottom: 20px;">
Wan: Open and Advanced Large-Scale Video Generative Models.
</div>
""")
with gr.Row():
with gr.Column():
resolution = gr.Dropdown(
label='Resolution',
choices=['------', '720P'],
value='------')
flf2vid_image_first = gr.Image(
type="pil",
label="Upload First Frame",
elem_id="image_upload",
)
flf2vid_image_last = gr.Image(
type="pil",
label="Upload Last Frame",
elem_id="image_upload",
)
flf2vid_prompt = gr.Textbox(
label="Prompt",
placeholder="Describe the video you want to generate",
)
tar_lang = gr.Radio(
choices=["ZH", "EN"],
label="Target language of prompt enhance",
value="ZH")
run_p_button = gr.Button(value="Prompt Enhance")
with gr.Accordion("Advanced Options", open=True):
with gr.Row():
sd_steps = gr.Slider(
label="Diffusion steps",
minimum=1,
maximum=1000,
value=50,
step=1)
guide_scale = gr.Slider(
label="Guide scale",
minimum=0,
maximum=20,
value=5.0,
step=1)
with gr.Row():
shift_scale = gr.Slider(
label="Shift scale",
minimum=0,
maximum=20,
value=5.0,
step=1)
seed = gr.Slider(
label="Seed",
minimum=-1,
maximum=2147483647,
step=1,
value=-1)
n_prompt = gr.Textbox(
label="Negative Prompt",
placeholder="Describe the negative prompt you want to add"
)
run_flf2v_button = gr.Button("Generate Video")
with gr.Column():
result_gallery = gr.Video(
label='Generated Video', interactive=False, height=600)
resolution.input(
fn=load_model, inputs=[resolution], outputs=[resolution])
run_p_button.click(
fn=prompt_enc,
inputs=[
flf2vid_prompt, flf2vid_image_first, flf2vid_image_last,
tar_lang
],
outputs=[flf2vid_prompt])
run_flf2v_button.click(
fn=flf2v_generation,
inputs=[
flf2vid_prompt, flf2vid_image_first, flf2vid_image_last,
resolution, sd_steps, guide_scale, shift_scale, seed, n_prompt
],
outputs=[result_gallery],
)
return demo
# Main
def _parse_args():
parser = argparse.ArgumentParser(
description="Generate a video from a text prompt or image using Gradio")
parser.add_argument(
"--ckpt_dir_720p",
type=str,
default=None,
help="The path to the checkpoint directory.")
parser.add_argument(
"--prompt_extend_method",
type=str,
default="local_qwen",
choices=["dashscope", "local_qwen"],
help="The prompt extend method to use.")
parser.add_argument(
"--prompt_extend_model",
type=str,
default=None,
help="The prompt extend model to use.")
args = parser.parse_args()
assert args.ckpt_dir_720p is not None, "Please specify the checkpoint directory."
return args
if __name__ == '__main__':
args = _parse_args()
print("Step1: Init prompt_expander...", end='', flush=True)
if args.prompt_extend_method == "dashscope":
prompt_expander = DashScopePromptExpander(
model_name=args.prompt_extend_model, is_vl=True)
elif args.prompt_extend_method == "local_qwen":
prompt_expander = QwenPromptExpander(
model_name=args.prompt_extend_model, is_vl=True, device=0)
else:
raise NotImplementedError(
f"Unsupported prompt_extend_method: {args.prompt_extend_method}")
print("done", flush=True)
demo = gradio_interface()
demo.launch(server_name="0.0.0.0", share=False, server_port=7860)
================================================
FILE: gradio/i2v_14B_singleGPU.py
================================================
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
import gc
import os
import os.path as osp
import sys
import warnings
import gradio as gr
warnings.filterwarnings('ignore')
# Model
sys.path.insert(
0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
import wan
from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
from wan.utils.utils import cache_video
# Global Var
prompt_expander = None
wan_i2v_480P = None
wan_i2v_720P = None
# Button Func
def load_model(value):
global wan_i2v_480P, wan_i2v_720P
if value == '------':
print("No model loaded")
return '------'
if value == '720P':
if args.ckpt_dir_720p is None:
print("Please specify the checkpoint directory for 720P model")
return '------'
if wan_i2v_720P is not None:
pass
else:
del wan_i2v_480P
gc.collect()
wan_i2v_480P = None
print("load 14B-720P i2v model...", end='', flush=True)
cfg = WAN_CONFIGS['i2v-14B']
wan_i2v_720P = wan.WanI2V(
config=cfg,
checkpoint_dir=args.ckpt_dir_720p,
device_id=0,
rank=0,
t5_fsdp=False,
dit_fsdp=False,
use_usp=False,
)
print("done", flush=True)
return '720P'
if value == '480P':
if args.ckpt_dir_480p is None:
print("Please specify the checkpoint directory for 480P model")
return '------'
if wan_i2v_480P is not None:
pass
else:
del wan_i2v_720P
gc.collect()
wan_i2v_720P = None
print("load 14B-480P i2v model...", end='', flush=True)
cfg = WAN_CONFIGS['i2v-14B']
wan_i2v_480P = wan.WanI2V(
config=cfg,
checkpoint_dir=args.ckpt_dir_480p,
device_id=0,
rank=0,
t5_fsdp=False,
dit_fsdp=False,
use_usp=False,
)
print("done", flush=True)
return '480P'
return value
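# Illustrative sketch (hypothetical class, not part of the Wan2.1 code):
# load_model() above keeps at most one I2V pipeline resident, dropping the
# other resolution's model before loading a new one. The same policy as a
# small reusable cache; `factory` stands in for a call such as
# wan.WanI2V(...).
import gc


class _SketchSingleSlotCache:
    """Hold at most one loaded model, keyed by a name such as '720P'."""

    def __init__(self):
        self.key = None
        self.model = None

    def get(self, key, factory):
        if self.key == key and self.model is not None:
            return self.model  # already resident: reuse it
        # Drop the previous model before building the new one so its memory
        # can be reclaimed first. Freeing cached GPU memory may additionally
        # need torch.cuda.empty_cache(), which this sketch does not call.
        self.model = None
        self.key = None
        gc.collect()
        self.model = factory()
        self.key = key
        return self.model
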
def prompt_enc(prompt, img, tar_lang):
print('prompt extend...')
if img is None:
print('Please upload an image')
return prompt
global prompt_expander
prompt_output = prompt_expander(
prompt, image=img, tar_lang=tar_lang.lower())
if not prompt_output.status:
return prompt
else:
return prompt_output.prompt
def i2v_generation(img2vid_prompt, img2vid_image, resolution, sd_steps,
guide_scale, shift_scale, seed, n_prompt):
# print(f"{img2vid_prompt},{resolution},{sd_steps},{guide_scale},{shift_scale},{seed},{n_prompt}")
if resolution == '------':
print(
'Please specify at least one checkpoint directory and select a resolution'
)
return None
else:
if resolution == '720P':
global wan_i2v_720P
video = wan_i2v_720P.generate(
img2vid_prompt,
img2vid_image,
max_area=MAX_AREA_CONFIGS['720*1280'],
shift=shift_scale,
sampling_steps=sd_steps,
guide_scale=guide_scale,
n_prompt=n_prompt,
seed=seed,
offload_model=True)
else:
global wan_i2v_480P
video = wan_i2v_480P.generate(
img2vid_prompt,
img2vid_image,
max_area=MAX_AREA_CONFIGS['480*832'],
shift=shift_scale,
sampling_steps=sd_steps,
guide_scale=guide_scale,
n_prompt=n_prompt,
seed=seed,
offload_model=True)
cache_video(
tensor=video[None],
save_file="example.mp4",
fps=16,
nrow=1,
normalize=True,
value_range=(-1, 1))
return "example.mp4"
# Interface
def gradio_interface():
with gr.Blocks() as demo:
gr.Markdown("""
<div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
Wan2.1 (I2V-14B)
</div>
<div style="text-align: center; font-size: 16px; font-weight: normal; margin-bottom: 20px;">
Wan: Open and Advanced Large-Scale Video Generative Models.
</div>
""")
with gr.Row():
with gr.Column():
resolution = gr.Dropdown(
label='Resolution',
choices=['------', '720P', '480P'],
value='------')
img2vid_image = gr.Image(
type="pil",
label="Upload Input Image",
elem_id="image_upload",
)
img2vid_prompt = gr.Textbox(
label="Prompt",
placeholder="Describe the video you want to generate",
)
tar_lang = gr.Radio(
choices=["ZH", "EN"],
label="Target language of prompt enhance",
value="ZH")
run_p_button = gr.Button(value="Prompt Enhance")
with gr.Accordion("Advanced Options", open=True):
with gr.Row():
sd_steps = gr.Slider(
label="Diffusion steps",
minimum=1,
maximum=1000,
value=50,
step=1)
guide_scale = gr.Slider(
label="Guide scale",
minimum=0,
maximum=20,
value=5.0,
step=1)
with gr.Row():
shift_scale = gr.Slider(
label="Shift scale",
minimum=0,
maximum=10,
value=5.0,
step=1)
seed = gr.Slider(
label="Seed",
minimum=-1,
maximum=2147483647,
step=1,
value=-1)
n_prompt = gr.Textbox(
label="Negative Prompt",
placeholder="Describe the negative prompt you want to add"
)
run_i2v_button = gr.Button("Generate Video")
with gr.Column():
result_gallery = gr.Video(
label='Generated Video', interactive=False, height=600)
resolution.input(
fn=load_model, inputs=[resolution], outputs=[resolution])
run_p_button.click(
fn=prompt_enc,
inputs=[img2vid_prompt, img2vid_image, tar_lang],
outputs=[img2vid_prompt])
run_i2v_button.click(
fn=i2v_generation,
inputs=[
img2vid_prompt, img2vid_image, resolution, sd_steps,
guide_scale, shift_scale, seed, n_prompt
],
outputs=[result_gallery],
)
return demo
# Main
def _parse_args():
parser = argparse.ArgumentParser(
description="Generate a video from a text prompt or image using Gradio")
parser.add_argument(
"--ckpt_dir_720p",
type=str,
default=None,
help="The path to the checkpoint directory.")
parser.add_argument(
"--ckpt_dir_480p",
type=str,
default=None,
help="The path to the checkpoint directory.")
parser.add_argument(
"--prompt_extend_method",
type=str,
default="local_qwen",
choices=["dashscope", "local_qwen"],
help="The prompt extend method to use.")
parser.add_argument(
"--prompt_extend_model",
type=str,
default=None,
help="The prompt extend model to use.")
args = parser.parse_args()
assert args.ckpt_dir_720p is not None or args.ckpt_dir_480p is not None, "Please specify at least one checkpoint directory."
return args
if __name__ == '__main__':
args = _parse_args()
print("Step1: Init prompt_expander...", end='', flush=True)
if args.prompt_extend_method == "dashscope":
prompt_expander = DashScopePromptExpander(
model_name=args.prompt_extend_model, is_vl=True)
elif args.prompt_extend_method == "local_qwen":
prompt_expander = QwenPromptExpander(
model_name=args.prompt_extend_model, is_vl=True, device=0)
else:
raise NotImplementedError(
f"Unsupported prompt_extend_method: {args.prompt_extend_method}")
print("done", flush=True)
demo = gradio_interface()
demo.launch(server_name="0.0.0.0", share=False, server_port=7860)
================================================
FILE: gradio/t2i_14B_singleGPU.py
================================================
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
import os
import os.path as osp
import sys
import warnings
import gradio as gr
warnings.filterwarnings('ignore')
# Model
sys.path.insert(
0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
import wan
from wan.configs import WAN_CONFIGS
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
from wan.utils.utils import cache_image
# Global Var
prompt_expander = None
wan_t2i = None
# Button Func
def prompt_enc(prompt, tar_lang):
global prompt_expander
prompt_output = prompt_expander(prompt, tar_lang=tar_lang.lower())
if not prompt_output.status:
return prompt
else:
return prompt_output.prompt
def t2i_generation(txt2img_prompt, resolution, sd_steps, guide_scale,
shift_scale, seed, n_prompt):
global wan_t2i
# print(f"{txt2img_prompt},{resolution},{sd_steps},{guide_scale},{shift_scale},{seed},{n_prompt}")
W = int(resolution.split("*")[0])
H = int(resolution.split("*")[1])
video = wan_t2i.generate(
txt2img_prompt,
size=(W, H),
frame_num=1,
shift=shift_scale,
sampling_steps=sd_steps,
guide_scale=guide_scale,
n_prompt=n_prompt,
seed=seed,
offload_model=True)
cache_image(
tensor=video.squeeze(1)[None],
save_file="example.png",
nrow=1,
normalize=True,
value_range=(-1, 1))
return "example.png"
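# Illustrative sketch (hypothetical helper): t2i_generation() above splits a
# 'Width*Height' dropdown value with resolution.split("*") and passes
# size=(W, H). A stricter parser that rejects malformed strings instead of
# failing with an IndexError:
def _sketch_parse_resolution(resolution):
    """Parse 'W*H' (e.g. '720*1280') into an (int, int) tuple (W, H)."""
    parts = resolution.split("*")
    if len(parts) != 2:
        raise ValueError(f"Expected 'Width*Height', got {resolution!r}")
    width, height = (int(p) for p in parts)  # int() raises on non-digits
    if width <= 0 or height <= 0:
        raise ValueError(f"Resolution must be positive, got {resolution!r}")
    return width, height
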
# Interface
def gradio_interface():
with gr.Blocks() as demo:
gr.Markdown("""
<div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
Wan2.1 (T2I-14B)
</div>
<div style="text-align: center; font-size: 16px; font-weight: normal; margin-bottom: 20px;">
Wan: Open and Advanced Large-Scale Video Generative Models.
</div>
""")
with gr.Row():
with gr.Column():
txt2img_prompt = gr.Textbox(
label="Prompt",
placeholder="Describe the image you want to generate",
)
tar_lang = gr.Radio(
choices=["ZH", "EN"],
label="Target language of prompt enhance",
value="ZH")
run_p_button = gr.Button(value="Prompt Enhance")
with gr.Accordion("Advanced Options", open=True):
resolution = gr.Dropdown(
label='Resolution(Width*Height)',
choices=[
'720*1280', '1280*720', '960*960', '1088*832',
'832*1088', '480*832', '832*480', '624*624',
'704*544', '544*704'
],
value='720*1280')
with gr.Row():
sd_steps = gr.Slider(
label="Diffusion steps",
minimum=1,
maximum=1000,
value=50,
step=1)
guide_scale = gr.Slider(
label="Guide scale",
minimum=0,
maximum=20,
value=5.0,
step=1)
with gr.Row():
shift_scale = gr.Slider(
label="Shift scale",
minimum=0,
maximum=10,
value=5.0,
step=1)
seed = gr.Slider(
label="Seed",
minimum=-1,
maximum=2147483647,
step=1,
value=-1)
n_prompt = gr.Textbox(
label="Negative Prompt",
placeholder="Describe the negative prompt you want to add"
)
run_t2i_button = gr.Button("Generate Image")
with gr.Column():
result_gallery = gr.Image(
label='Generated Image', interactive=False, height=600)
run_p_button.click(
fn=prompt_enc,
inputs=[txt2img_prompt, tar_lang],
outputs=[txt2img_prompt])
run_t2i_button.click(
fn=t2i_generation,
inputs=[
txt2img_prompt, resolution, sd_steps, guide_scale, shift_scale,
seed, n_prompt
],
outputs=[result_gallery],
)
return demo
# Main
def _parse_args():
parser = argparse.ArgumentParser(
description="Generate an image from a text prompt using Gradio")
parser.add_argument(
"--ckpt_dir",
type=str,
default="cache",
help="The path to the checkpoint directory.")
parser.add_argument(
"--prompt_extend_method",
type=str,
default="local_qwen",
choices=["dashscope", "local_qwen"],
help="The prompt extend method to use.")
parser.add_argument(
"--prompt_extend_model",
type=str,
default=None,
help="The prompt extend model to use.")
args = parser.parse_args()
return args
if __name__ == '__main__':
args = _parse_args()
print("Step1: Init prompt_expander...", end='', flush=True)
if args.prompt_extend_method == "dashscope":
prompt_expander = DashScopePromptExpander(
model_name=args.prompt_extend_model, is_vl=False)
elif args.prompt_extend_method == "local_qwen":
prompt_expander = QwenPromptExpander(
model_name=args.prompt_extend_model, is_vl=False, device=0)
else:
raise NotImplementedError(
f"Unsupported prompt_extend_method: {args.prompt_extend_method}")
print("done", flush=True)
print("Step2: Init 14B t2i model...", end='', flush=True)
cfg = WAN_CONFIGS['t2i-14B']
wan_t2i = wan.WanT2V(
config=cfg,
checkpoint_dir=args.ckpt_dir,
device_id=0,
rank=0,
t5_fsdp=False,
dit_fsdp=False,
use_usp=False,
)
print("done", flush=True)
demo = gradio_interface()
demo.launch(server_name="0.0.0.0", share=False, server_port=7860)
================================================
FILE: gradio/t2v_1.3B_singleGPU.py
================================================
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
import os
import os.path as osp
import sys
import warnings
import gradio as gr
warnings.filterwarnings('ignore')
# Model
sys.path.insert(
0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
import wan
from wan.configs import WAN_CONFIGS
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
from wan.utils.utils import cache_video
# Global Var
prompt_expander = None
wan_t2v = None
# Button Func
def prompt_enc(prompt, tar_lang):
global prompt_expander
prompt_output = prompt_expander(prompt, tar_lang=tar_lang.lower())
if not prompt_output.status:
return prompt
else:
return prompt_output.prompt
def t2v_generation(txt2vid_prompt, resolution, sd_steps, guide_scale,
shift_scale, seed, n_prompt):
global wan_t2v
# print(f"{txt2vid_prompt},{resolution},{sd_steps},{guide_scale},{shift_scale},{seed},{n_prompt}")
W = int(resolution.split("*")[0])
H = int(resolution.split("*")[1])
video = wan_t2v.generate(
txt2vid_prompt,
size=(W, H),
shift=shift_scale,
sampling_steps=sd_steps,
guide_scale=guide_scale,
n_prompt=n_prompt,
seed=seed,
offload_model=True)
cache_video(
tensor=video[None],
save_file="example.mp4",
fps=16,
nrow=1,
normalize=True,
value_range=(-1, 1))
return "example.mp4"
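# Illustrative sketch (hypothetical helper): the Seed slider above defaults
# to -1. This assumes the convention, believed to be used by the Wan
# pipelines' generate() methods, that a negative seed means "draw a random
# one"; passing an explicit random.Random makes the draw reproducible.
import random
import sys


def _sketch_resolve_seed(seed, rng=None):
    """Return seed as an int if non-negative, else a random non-negative int."""
    if seed >= 0:
        return int(seed)  # Gradio sliders may deliver floats such as 42.0
    rng = rng or random.Random()
    return rng.randint(0, sys.maxsize)
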
# Interface
def gradio_interface():
with gr.Blocks() as demo:
gr.Markdown("""
<div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
Wan2.1 (T2V-1.3B)
</div>
<div style="text-align: center; font-size: 16px; font-weight: normal; margin-bottom: 20px;">
Wan: Open and Advanced Large-Scale Video Generative Models.
</div>
""")
with gr.Row():
with gr.Column():
txt2vid_prompt = gr.Textbox(
label="Prompt",
placeholder="Describe the video you want to generate",
)
tar_lang = gr.Radio(
choices=["ZH", "EN"],
label="Target language of prompt enhance",
value="ZH")
run_p_button = gr.Button(value="Prompt Enhance")
with gr.Accordion("Advanced Options", open=True):
resolution = gr.Dropdown(
label='Resolution(Width*Height)',
choices=[
'480*832',
'832*480',
'624*624',
'704*544',
'544*704',
],
value='480*832')
with gr.Row():
sd_steps = gr.Slider(
label="Diffusion steps",
minimum=1,
maximum=1000,
value=50,
step=1)
guide_scale = gr.Slider(
label="Guide scale",
minimum=0,
maximum=20,
value=6.0,
step=1)
with gr.Row():
shift_scale = gr.Slider(
label="Shift scale",
minimum=0,
maximum=20,
value=8.0,
step=1)
seed = gr.Slider(
label="Seed",
minimum=-1,
maximum=2147483647,
step=1,
value=-1)
n_prompt = gr.Textbox(
label="Negative Prompt",
placeholder="Describe the negative prompt you want to add"
)
run_t2v_button = gr.Button("Generate Video")
with gr.Column():
result_gallery = gr.Video(
label='Generated Video', interactive=False, height=600)
run_p_button.click(
fn=prompt_enc,
inputs=[txt2vid_prompt, tar_lang],
outputs=[txt2vid_prompt])
run_t2v_button.click(
fn=t2v_generation,
inputs=[
txt2vid_prompt, resolution, sd_steps, guide_scale, shift_scale,
seed, n_prompt
],
outputs=[result_gallery],
)
return demo
# Main
def _parse_args():
parser = argparse.ArgumentParser(
description="Generate a video from a text prompt using Gradio")
parser.add_argument(
"--ckpt_dir",
type=str,
default="cache",
help="The path to the checkpoint directory.")
parser.add_argument(
"--prompt_extend_method",
type=str,
default="local_qwen",
choices=["dashscope", "local_qwen"],
help="The prompt extend method to use.")
parser.add_argument(
"--prompt_extend_model",
type=str,
default=None,
help="The prompt extend model to use.")
args = parser.parse_args()
return args
if __name__ == '__main__':
args = _parse_args()
print("Step1: Init prompt_expander...", end='', flush=True)
if args.prompt_extend_method == "dashscope":
prompt_expander = DashScopePromptExpander(
model_name=args.prompt_extend_model, is_vl=False)
elif args.prompt_extend_method == "local_qwen":
prompt_expander = QwenPromptExpander(
model_name=args.prompt_extend_model, is_vl=False, device=0)
else:
raise NotImplementedError(
f"Unsupported prompt_extend_method: {args.prompt_extend_method}")
print("done", flush=True)
print("Step2: Init 1.3B t2v model...", end='', flush=True)
cfg = WAN_CONFIGS['t2v-1.3B']
wan_t2v = wan.WanT2V(
config=cfg,
checkpoint_dir=args.ckpt_dir,
device_id=0,
rank=0,
t5_fsdp=False,
dit_fsdp=False,
use_usp=False,
)
print("done", flush=True)
demo = gradio_interface()
demo.launch(server_name="0.0.0.0", share=False, server_port=7860)
================================================
FILE: gradio/t2v_14B_singleGPU.py
================================================
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
import os
import os.path as osp
import sys
import warnings
import gradio as gr
warnings.filterwarnings('ignore')
# Model
sys.path.insert(
0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
import wan
from wan.configs import WAN_CONFIGS
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
from wan.utils.utils import cache_video
# Global Var
prompt_expander = None
wan_t2v = None
# Button Func
def prompt_enc(prompt, tar_lang):
global prompt_expander
prompt_output = prompt_expander(prompt, tar_lang=tar_lang.lower())
if not prompt_output.status:
return prompt
else:
return prompt_output.prompt
def t2v_generation(txt2vid_prompt, resolution, sd_steps, guide_scale,
shift_scale, seed, n_prompt):
global wan_t2v
# print(f"{txt2vid_prompt},{resolution},{sd_steps},{guide_scale},{shift_scale},{seed},{n_prompt}")
W = int(resolution.split("*")[0])
H = int(resolution.split("*")[1])
video = wan_t2v.generate(
txt2vid_prompt,
size=(W, H),
shift=shift_scale,
sampling_steps=sd_steps,
guide_scale=guide_scale,
n_prompt=n_prompt,
seed=seed,
offload_model=True)
cache_video(
tensor=video[None],
save_file="example.mp4",
fps=16,
nrow=1,
normalize=True,
value_range=(-1, 1))
return "example.mp4"
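# Illustrative sketch (hypothetical helper, NumPy instead of torch): with
# normalize=True and value_range=(-1, 1), decoded frames in [-1, 1] are
# clamped and rescaled to [0, 1] before being written as 8-bit video. This
# assumes that mapping; it is not the cache_video implementation itself.
import numpy as np


def _sketch_to_uint8(frames, value_range=(-1.0, 1.0)):
    """Map float frames in value_range to uint8 pixels in [0, 255]."""
    low, high = value_range
    clipped = np.clip(frames, low, high)   # out-of-range values saturate
    unit = (clipped - low) / (high - low)  # now in [0, 1]
    return (unit * 255).astype(np.uint8)   # astype truncates toward zero
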
# Interface
def gradio_interface():
with gr.Blocks() as demo:
gr.Markdown("""
<div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
Wan2.1 (T2V-14B)
</div>
<div style="text-align: center; font-size: 16px; font-weight: normal; margin-bottom: 20px;">
Wan: Open and Advanced Large-Scale Video Generative Models.
</div>
""")
with gr.Row():
with gr.Column():
txt2vid_prompt = gr.Textbox(
label="Prompt",
placeholder="Describe the video you want to generate",
)
tar_lang = gr.Radio(
choices=["ZH", "EN"],
label="Target language of prompt enhance",
value="ZH")
run_p_button = gr.Button(value="Prompt Enhance")
with gr.Accordion("Advanced Options", open=True):
resolution = gr.Dropdown(
label='Resolution(Width*Height)',
choices=[
'720*1280', '1280*720', '960*960', '1088*832',
'832*1088', '480*832', '832*480', '624*624',
'704*544', '544*704'
],
value='720*1280')
with gr.Row():
sd_steps = gr.Slider(
label="Diffusion steps",
minimum=1,
maximum=1000,
value=50,
step=1)
guide_scale = gr.Slider(
label="Guide scale",
minimum=0,
maximum=20,
value=5.0,
step=1)
with gr.Row():
shift_scale = gr.Slider(
label="Shift scale",
minimum=0,
maximum=10,
value=5.0,
step=1)
seed = gr.Slider(
label="Seed",
minimum=-1,
maximum=2147483647,
step=1,
value=-1)
n_prompt = gr.Textbox(
label="Negative Prompt",
placeholder="Describe the negative prompt you want to add"
)
run_t2v_button = gr.Button("Generate Video")
with gr.Column():
result_gallery = gr.Video(
label='Generated Video', interactive=False, height=600)
run_p_button.click(
fn=prompt_enc,
inputs=[txt2vid_prompt, tar_lang],
outputs=[txt2vid_prompt])
run_t2v_button.click(
fn=t2v_generation,
inputs=[
txt2vid_prompt, resolution, sd_steps, guide_scale, shift_scale,
seed, n_prompt
],
outputs=[result_gallery],
)
return demo
# Main
def _parse_args():
parser = argparse.ArgumentParser(
description="Generate a video from a text prompt using Gradio")
parser.add_argument(
"--ckpt_dir",
type=str,
default="cache",
help="The path to the checkpoint directory.")
parser.add_argument(
"--prompt_extend_method",
type=str,
default="local_qwen",
choices=["dashscope", "local_qwen"],
help="The prompt extend method to use.")
parser.add_argument(
"--prompt_extend_model",
type=str,
default=None,
help="The prompt extend model to use.")
args = parser.parse_args()
return args
if __name__ == '__main__':
    args = _parse_args()
    print("Step1: Init prompt_expander...", end='', flush=True)
    if args.prompt_extend_method == "dashscope":
        prompt_expander = DashScopePromptExpander(
            model_name=args.prompt_extend_model, is_vl=False)
    elif args.prompt_extend_method == "local_qwen":
        prompt_expander = QwenPromptExpander(
            model_name=args.prompt_extend_model, is_vl=False, device=0)
    else:
        raise NotImplementedError(
            f"Unsupported prompt_extend_method: {args.prompt_extend_method}")
    print("done", flush=True)
    print("Step2: Init 14B t2v model...", end='', flush=True)
    cfg = WAN_CONFIGS['t2v-14B']
    wan_t2v = wan.WanT2V(
        config=cfg,
        checkpoint_dir=args.ckpt_dir,
        device_id=0,
        rank=0,
        t5_fsdp=False,
        dit_fsdp=False,
        use_usp=False,
    )
    print("done", flush=True)
    demo = gradio_interface()
    demo.launch(server_name="0.0.0.0", share=False, server_port=7860)
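The resolution dropdown in the demo above passes its callback a `Width*Height` string. Below is a minimal sketch of turning such a string into integers, assuming that format. `parse_size` is a hypothetical helper and is not part of the repository. The default multiple of 16 is an assumption derived from the configs: a spatial `vae_stride` of 8 times a spatial `patch_size` of 2. Every choice offered in the dropdown satisfies it.

```python
def parse_size(resolution: str, multiple: int = 16) -> tuple[int, int]:
    """Split a 'Width*Height' string into (width, height).

    Hypothetical helper. `multiple` assumes VAE spatial stride 8 times DiT
    patch size 2, as in the Wan configs; change it if your config differs.
    """
    try:
        width_str, height_str = resolution.split("*")
        width, height = int(width_str), int(height_str)
    except ValueError as exc:  # wrong number of parts or non-integer text
        raise ValueError(f"expected 'Width*Height', got {resolution!r}") from exc
    if width % multiple or height % multiple:
        raise ValueError(f"{resolution!r} is not a multiple of {multiple}")
    return width, height


if __name__ == "__main__":
    for choice in ["720*1280", "1088*832", "624*624"]:
        print(choice, "->", parse_size(choice))
```

Validating during parsing means a malformed or oddly sized custom value fails with a readable message before any model work starts.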
================================================
FILE: gradio/vace.py
================================================
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import argparse
import datetime
import os
import sys
import imageio
import numpy as np
import torch
import gradio as gr
sys.path.insert(
0, os.path.sep.join(os.path.realpath(__file__).split(os.path.sep)[:-2]))
import wan
from wan import WanVace, WanVaceMP
from wan.configs import SIZE_CONFIGS, WAN_CONFIGS
class FixedSizeQueue:
    def __init__(self, max_size):
        self.max_size = max_size
        self.queue = []
    def add(self, item):
        self.queue.insert(0, item)
        if len(self.queue) > self.max_size:
            self.queue.pop()
    def get(self):
        return self.queue
    def __repr__(self):
        return str(self.queue)
class VACEInference:
    def __init__(self,
                 cfg,
                 skip_load=False,
                 gallery_share=True,
                 gallery_share_limit=5):
        self.cfg = cfg
        self.save_dir = cfg.save_dir
        self.gallery_share = gallery_share
        self.gallery_share_data = FixedSizeQueue(max_size=gallery_share_limit)
        if not skip_load:
            # Read the flag from `cfg` (the parsed arguments passed in), not the global `args`.
            if not cfg.mp:
                self.pipe = WanVace(
                    config=WAN_CONFIGS[cfg.model_name],
                    checkpoint_dir=cfg.ckpt_dir,
                    device_id=0,
                    rank=0,
                    t5_fsdp=False,
                    dit_fsdp=False,
                    use_usp=False,
                )
            else:
                self.pipe = WanVaceMP(
                    config=WAN_CONFIGS[cfg.model_name],
                    checkpoint_dir=cfg.ckpt_dir,
                    use_usp=True,
                    ulysses_size=cfg.ulysses_size,
                    ring_size=cfg.ring_size)
    def create_ui(self, *args, **kwargs):
        gr.Markdown("""
<div style="text-align: center; font-size: 24px; font-weight: bold; margin-bottom: 15px;">
<a href="https://ali-vilab.github.io/VACE-Page/" style="text-decoration: none; color: inherit;">VACE-WAN Demo</a>
</div>
""")
        with gr.Row(variant='panel', equal_height=True):
            with gr.Column(scale=1, min_width=0):
                self.src_video = gr.Video(
                    label="src_video",
                    sources=['upload'],
                    value=None,
                    interactive=True)
            with gr.Column(scale=1, min_width=0):
                self.src_mask = gr.Video(
                    label="src_mask",
                    sources=['upload'],
                    value=None,
                    interactive=True)
        #
        with gr.Row(variant='panel', equal_height=True):
            with gr.Column(scale=1, min_width=0):
                with gr.Row(equal_height=True):
                    self.src_ref_image_1 = gr.Image(
                        label='src_ref_image_1',
                        height=200,
                        interactive=True,
                        type='filepath',
                        image_mode='RGB',
                        sources=['upload'],
                        elem_id="src_ref_image_1",
                        format='png')
                    self.src_ref_image_2 = gr.Image(
                        label='src_ref_image_2',
                        height=200,
                        interactive=True,
                        type='filepath',
                        image_mode='RGB',
                        sources=['upload'],
                        elem_id="src_ref_image_2",
                        format='png')
                    self.src_ref_image_3 = gr.Image(
                        label='src_ref_image_3',
                        height=200,
                        interactive=True,
                        type='filepath',
                        image_mode='RGB',
                        sources=['upload'],
                        elem_id="src_ref_image_3",
                        format='png')
        with gr.Row(variant='panel', equal_height=True):
            with gr.Column(scale=1):
                self.prompt = gr.Textbox(
                    show_label=False,
                    placeholder="positive_prompt_input",
                    elem_id='positive_prompt',
                    container=True,
                    autofocus=True,
                    elem_classes='type_row',
                    visible=True,
                    lines=2)
                self.negative_prompt = gr.Textbox(
                    show_label=False,
                    value=self.pipe.config.sample_neg_prompt,
                    placeholder="negative_prompt_input",
                    elem_id='negative_prompt',
                    container=True,
                    autofocus=False,
                    elem_classes='type_row',
                    visible=True,
                    interactive=True,
                    lines=1)
        #
        with gr.Row(variant='panel', equal_height=True):
            with gr.Column(scale=1, min_width=0):
                with gr.Row(equal_height=True):
                    self.shift_scale = gr.Slider(
                        label='shift_scale',
                        minimum=0.0,
                        maximum=100.0,
                        step=1.0,
                        value=16.0,
                        interactive=True)
                    self.sample_steps = gr.Slider(
                        label='sample_steps',
                        minimum=1,
                        maximum=100,
                        step=1,
                        value=25,
                        interactive=True)
                    self.context_scale = gr.Slider(
                        label='context_scale',
                        minimum=0.0,
                        maximum=2.0,
                        step=0.1,
                        value=1.0,
                        interactive=True)
                    self.guide_scale = gr.Slider(
                        label='guide_scale',
                        minimum=1,
                        maximum=10,
                        step=0.5,
                        value=5.0,
                        interactive=True)
                    self.infer_seed = gr.Slider(
                        minimum=-1, maximum=10000000, value=2025, label="Seed")
        #
        with gr.Accordion(label="Usable without source video", open=False):
            with gr.Row(equal_height=True):
                self.output_height = gr.Textbox(
                    label='resolutions_height',
                    # value=480,
                    value=720,
                    interactive=True)
                self.output_width = gr.Textbox(
                    label='resolutions_width',
                    # value=832,
                    value=1280,
                    interactive=True)
                self.frame_rate = gr.Textbox(
                    label='frame_rate', value=16, interactive=True)
                self.num_frames = gr.Textbox(
                    label='num_frames', value=81, interactive=True)
        #
        with gr.Row(equal_height=True):
            with gr.Column(scale=5):
                self.generate_button = gr.Button(
                    value='Run',
                    elem_classes='type_row',
                    elem_id='generate_button',
                    visible=True)
            with gr.Column(scale=1):
                self.refresh_button = gr.Button(value='\U0001f504')  # 🔄
        #
        self.output_gallery = gr.Gallery(
            label="output_gallery",
            value=[],
            interactive=False,
            allow_preview=True,
            preview=True)
    def generate(self, output_gallery, src_video, src_mask, src_ref_image_1,
                 src_ref_image_2, src_ref_image_3, prompt, negative_prompt,
                 shift_scale, sample_steps, context_scale, guide_scale,
                 infer_seed, output_height, output_width, frame_rate,
                 num_frames):
        output_height, output_width, frame_rate, num_frames = int(
            output_height), int(output_width), int(frame_rate), int(num_frames)
        src_ref_images = [
            x for x in [src_ref_image_1, src_ref_image_2, src_ref_image_3]
            if x is not None
        ]
        src_video, src_mask, src_ref_images = self.pipe.prepare_source(
            [src_video], [src_mask], [src_ref_images],
            num_frames=num_frames,
            image_size=SIZE_CONFIGS[f"{output_width}*{output_height}"],
            device=self.pipe.device)
        video = self.pipe.generate(
            prompt,
            src_video,
            src_mask,
            src_ref_images,
            size=(output_width, output_height),
            context_scale=context_scale,
            shift=shift_scale,
            sampling_steps=sample_steps,
            guide_scale=guide_scale,
            n_prompt=negative_prompt,
            seed=infer_seed,
            offload_model=True)
        # Zero-padded timestamp; '%-H' is a non-portable strftime extension (unsupported on Windows).
        name = '{0:%Y%m%d%H%M%S}'.format(datetime.datetime.now())
        video_path = os.path.join(self.save_dir, f'cur_gallery_{name}.mp4')
        video_frames = (
            torch.clamp(video / 2 + 0.5, min=0.0, max=1.0).permute(1, 2, 3, 0) *
            255).cpu().numpy().astype(np.uint8)
        try:
            writer = imageio.get_writer(
                video_path,
                fps=frame_rate,
                codec='libx264',
                quality=8,
                macro_block_size=1)
            for frame in video_frames:
                writer.append_data(frame)
            writer.close()
            print(video_path)
        except Exception as e:
            raise gr.Error(f"Video save error: {e}")
        if self.gallery_share:
            self.gallery_share_data.add(video_path)
            return self.gallery_share_data.get()
        else:
            return [video_path]
    def set_callbacks(self, **kwargs):
        self.gen_inputs = [
            self.output_gallery, self.src_video, self.src_mask,
            self.src_ref_image_1, self.src_ref_image_2, self.src_ref_image_3,
            self.prompt, self.negative_prompt, self.shift_scale,
            self.sample_steps, self.context_scale, self.guide_scale,
            self.infer_seed, self.output_height, self.output_width,
            self.frame_rate, self.num_frames
        ]
        self.gen_outputs = [self.output_gallery]
        self.generate_button.click(
            self.generate,
            inputs=self.gen_inputs,
            outputs=self.gen_outputs,
            queue=True)
        self.refresh_button.click(
            lambda x: self.gallery_share_data.get()
            if self.gallery_share else x,
            inputs=[self.output_gallery],
            outputs=[self.output_gallery])
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Argparser for VACE-WAN Demo:\n')
    parser.add_argument(
        '--server_port',
        dest='server_port',
        help='Port the Gradio server listens on.',
        type=int,
        default=7860)
    parser.add_argument(
        '--server_name',
        dest='server_name',
        help='Host address the Gradio server binds to.',
        default='0.0.0.0')
    parser.add_argument(
        '--root_path',
        dest='root_path',
        help='Root path of the app when it is served behind a reverse proxy.',
        default=None)
    parser.add_argument(
        '--save_dir',
        dest='save_dir',
        help='Directory where generated videos are saved.',
        default='cache')
    parser.add_argument(
        "--mp",
        action="store_true",
        help="Use multiple GPUs.",
    )
    parser.add_argument(
        "--model_name",
        type=str,
        default="vace-14B",
        choices=list(WAN_CONFIGS.keys()),
        help="The model name to run.")
    parser.add_argument(
        "--ulysses_size",
        type=int,
        default=1,
        help="The size of the ulysses parallelism in DiT.")
    parser.add_argument(
        "--ring_size",
        type=int,
        default=1,
        help="The size of the ring attention parallelism in DiT.")
    parser.add_argument(
        "--ckpt_dir",
        type=str,
        # default='models/VACE-Wan2.1-1.3B-Preview',
        default='models/Wan2.1-VACE-14B/',
        help="The path to the checkpoint directory.",
    )
    parser.add_argument(
        "--offload_to_cpu",
        action="store_true",
        help="Offload unnecessary computations to the CPU.",
    )
    args = parser.parse_args()
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir, exist_ok=True)
    with gr.Blocks() as demo:
        infer_gr = VACEInference(
            args, skip_load=False, gallery_share=True, gallery_share_limit=5)
        infer_gr.create_ui()
        infer_gr.set_callbacks()
        allowed_paths = [args.save_dir]
        demo.queue(status_update_rate=1).launch(
            server_name=args.server_name,
            server_port=args.server_port,
            root_path=args.root_path,
            allowed_paths=allowed_paths,
            show_error=True,
            debug=True)
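`FixedSizeQueue` keeps the newest video path first and drops the oldest once `max_size` is exceeded. Below is a minimal sketch of the same behaviour using `collections.deque(maxlen=...)`. When a bounded deque is full, `appendleft` discards the item at the right end, which matches `insert(0, item)` followed by `pop()`. `BoundedGallery` is a hypothetical name. Unlike the original, which returns the internal list itself, `get()` here returns a list copy.

```python
from collections import deque


class BoundedGallery:
    """Most-recent-first collection capped at max_size (sketch of FixedSizeQueue)."""

    def __init__(self, max_size: int):
        # A bounded deque evicts from the opposite end when it is full.
        self.queue = deque(maxlen=max_size)

    def add(self, item):
        # Newest item goes to the front; the oldest falls off the back.
        self.queue.appendleft(item)

    def get(self):
        # Return a plain list copy so callers cannot mutate internal state.
        return list(self.queue)


if __name__ == "__main__":
    gallery = BoundedGallery(max_size=3)
    for i in range(5):
        gallery.add(f"video_{i}.mp4")
    print(gallery.get())  # ['video_4.mp4', 'video_3.mp4', 'video_2.mp4']
```

With the demo's limit of 5 entries, the difference between `deque.appendleft` (O(1)) and `list.insert(0, ...)` (O(n)) is negligible. The deque version mainly removes the manual length check.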
================================================
FILE: pyproject.toml
================================================
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
[project]
name = "wan"
version = "2.1.0"
description = "Wan: Open and Advanced Large-Scale Video Generative Models"
authors = [
{ name = "Wan Team", email = "wan.ai@alibabacloud.com" }
]
license = { file = "LICENSE.txt" }
readme = "README.md"
requires-python = ">=3.10,<4.0"
dependencies = [
"torch>=2.4.0",
"torchvision>=0.19.0",
"opencv-python>=4.9.0.80",
"diffusers>=0.31.0",
"transformers>=4.49.0",
"tokenizers>=0.20.3",
"accelerate>=1.1.1",
"tqdm",
"imageio",
"easydict",
"ftfy",
"dashscope",
"imageio-ffmpeg",
"flash_attn",
"gradio>=5.0.0",
"numpy>=1.23.5,<2"
]
[project.optional-dependencies]
dev = [
"pytest",
"black",
"flake8",
"isort",
"mypy",
"huggingface-hub[cli]"
]
[project.urls]
homepage = "https://wanxai.com"
documentation = "https://github.com/Wan-Video/Wan2.1"
repository = "https://github.com/Wan-Video/Wan2.1"
huggingface = "https://huggingface.co/Wan-AI/"
modelscope = "https://modelscope.cn/organization/Wan-AI"
discord = "https://discord.gg/p5XbdQV7"
[tool.setuptools]
packages = ["wan"]
[tool.setuptools.package-data]
"wan" = ["**/*.py"]
[tool.black]
line-length = 88
[tool.isort]
profile = "black"
[tool.mypy]
strict = true
================================================
FILE: requirements.txt
================================================
torch>=2.4.0
torchvision>=0.19.0
opencv-python>=4.9.0.80
diffusers>=0.31.0
transformers>=4.49.0
tokenizers>=0.20.3
accelerate>=1.1.1
tqdm
imageio
easydict
ftfy
dashscope
imageio-ffmpeg
flash_attn
gradio>=5.0.0
numpy>=1.23.5,<2
================================================
FILE: tests/README.md
================================================
Put all your models (Wan2.1-T2V-1.3B, Wan2.1-T2V-14B, Wan2.1-I2V-14B-480P, Wan2.1-I2V-14B-720P, VACE-Wan2.1-1.3B-Preview) in one folder and specify the maximum number of GPUs you want to use.
```bash
bash ./test.sh <local model dir> <gpu number>
```
================================================
FILE: tests/test.sh
================================================
#!/bin/bash
if [ "$#" -eq 2 ]; then
MODEL_DIR=$(realpath "$1")
GPUS=$2
else
echo "Usage: $0 <local model dir> <gpu number>"
exit 1
fi
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
REPO_ROOT="$(dirname "$SCRIPT_DIR")"
cd "$REPO_ROOT" || exit 1
PY_FILE=./generate.py
function t2v_1_3B() {
T2V_1_3B_CKPT_DIR="$MODEL_DIR/Wan2.1-T2V-1.3B"
# 1-GPU Test
echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_1_3B 1-GPU Test: "
python $PY_FILE --task t2v-1.3B --size 480*832 --ckpt_dir $T2V_1_3B_CKPT_DIR
# Multiple GPU Test
echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_1_3B Multiple GPU Test: "
torchrun --nproc_per_node=$GPUS $PY_FILE --task t2v-1.3B --ckpt_dir $T2V_1_3B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS
echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_1_3B Multiple GPU, prompt extend local_qwen: "
torchrun --nproc_per_node=$GPUS $PY_FILE --task t2v-1.3B --ckpt_dir $T2V_1_3B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_model "Qwen/Qwen2.5-3B-Instruct" --prompt_extend_target_lang "en"
if [ -n "${DASH_API_KEY+x}" ]; then
echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_1_3B Multiple GPU, prompt extend dashscope: "
torchrun --nproc_per_node=$GPUS $PY_FILE --task t2v-1.3B --ckpt_dir $T2V_1_3B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_method "dashscope"
else
echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> No DASH_API_KEY found, skip the dashscope extend test."
fi
}
function t2v_14B() {
T2V_14B_CKPT_DIR="$MODEL_DIR/Wan2.1-T2V-14B"
# 1-GPU Test
echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_14B 1-GPU Test: "
python $PY_FILE --task t2v-14B --size 480*832 --ckpt_dir $T2V_14B_CKPT_DIR
# Multiple GPU Test
echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_14B Multiple GPU Test: "
torchrun --nproc_per_node=$GPUS $PY_FILE --task t2v-14B --ckpt_dir $T2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS
echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_14B Multiple GPU, prompt extend local_qwen: "
torchrun --nproc_per_node=$GPUS $PY_FILE --task t2v-14B --ckpt_dir $T2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_model "Qwen/Qwen2.5-3B-Instruct" --prompt_extend_target_lang "en"
}
function t2i_14B() {
T2V_14B_CKPT_DIR="$MODEL_DIR/Wan2.1-T2V-14B"
# 1-GPU Test
echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2i_14B 1-GPU Test: "
python $PY_FILE --task t2i-14B --size 480*832 --ckpt_dir $T2V_14B_CKPT_DIR
# Multiple GPU Test
echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2i_14B Multiple GPU Test: "
torchrun --nproc_per_node=$GPUS $PY_FILE --task t2i-14B --ckpt_dir $T2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS
echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2i_14B Multiple GPU, prompt extend local_qwen: "
torchrun --nproc_per_node=$GPUS $PY_FILE --task t2i-14B --ckpt_dir $T2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_model "Qwen/Qwen2.5-3B-Instruct" --prompt_extend_target_lang "en"
}
function i2v_14B_480p() {
I2V_14B_CKPT_DIR="$MODEL_DIR/Wan2.1-I2V-14B-480P"
# 1-GPU Test
echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B 1-GPU Test: "
python $PY_FILE --task i2v-14B --size 832*480 --ckpt_dir $I2V_14B_CKPT_DIR
# Multiple GPU Test
echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B Multiple GPU Test: "
torchrun --nproc_per_node=$GPUS $PY_FILE --task i2v-14B --ckpt_dir $I2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS
echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B Multiple GPU, prompt extend local_qwen: "
torchrun --nproc_per_node=$GPUS $PY_FILE --task i2v-14B --ckpt_dir $I2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_model "Qwen/Qwen2.5-VL-3B-Instruct" --prompt_extend_target_lang "en"
if [ -n "${DASH_API_KEY+x}" ]; then
echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B Multiple GPU, prompt extend dashscope: "
torchrun --nproc_per_node=$GPUS $PY_FILE --task i2v-14B --ckpt_dir $I2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_method "dashscope"
else
echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> No DASH_API_KEY found, skip the dashscope extend test."
fi
}
function i2v_14B_720p() {
I2V_14B_CKPT_DIR="$MODEL_DIR/Wan2.1-I2V-14B-720P"
# 1-GPU Test
echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B 1-GPU Test: "
python $PY_FILE --task i2v-14B --size 720*1280 --ckpt_dir $I2V_14B_CKPT_DIR
# Multiple GPU Test
echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B Multiple GPU Test: "
torchrun --nproc_per_node=$GPUS $PY_FILE --task i2v-14B --ckpt_dir $I2V_14B_CKPT_DIR --size 720*1280 --dit_fsdp --t5_fsdp --ulysses_size $GPUS
}
function vace_1_3B() {
VACE_1_3B_CKPT_DIR="$MODEL_DIR/VACE-Wan2.1-1.3B-Preview/"
torchrun --nproc_per_node=$GPUS $PY_FILE --ulysses_size $GPUS --task vace-1.3B --size 480*832 --ckpt_dir $VACE_1_3B_CKPT_DIR
}
t2i_14B
t2v_1_3B
t2v_14B
i2v_14B_480p
i2v_14B_720p
vace_1_3B
================================================
FILE: wan/__init__.py
================================================
from . import configs, distributed, modules
from .first_last_frame2video import WanFLF2V
from .image2video import WanI2V
from .text2video import WanT2V
from .vace import WanVace, WanVaceMP
================================================
FILE: wan/configs/__init__.py
================================================
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import copy
import os
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
from .wan_i2v_14B import i2v_14B
from .wan_t2v_1_3B import t2v_1_3B
from .wan_t2v_14B import t2v_14B
# the config of t2i_14B is the same as t2v_14B
t2i_14B = copy.deepcopy(t2v_14B)
t2i_14B.__name__ = 'Config: Wan T2I 14B'
# the config of flf2v_14B is the same as i2v_14B
flf2v_14B = copy.deepcopy(i2v_14B)
flf2v_14B.__name__ = 'Config: Wan FLF2V 14B'
# "镜头切换" means "shot cut / scene transition"; it is prepended to the I2V negative prompt.
flf2v_14B.sample_neg_prompt = "镜头切换," + flf2v_14B.sample_neg_prompt
WAN_CONFIGS = {
't2v-14B': t2v_14B,
't2v-1.3B': t2v_1_3B,
'i2v-14B': i2v_14B,
't2i-14B': t2i_14B,
'flf2v-14B': flf2v_14B,
'vace-1.3B': t2v_1_3B,
'vace-14B': t2v_14B,
}
SIZE_CONFIGS = {
'720*1280': (720, 1280),
'1280*720': (1280, 720),
'480*832': (480, 832),
'832*480': (832, 480),
'1024*1024': (1024, 1024),
}
MAX_AREA_CONFIGS = {
'720*1280': 720 * 1280,
'1280*720': 1280 * 720,
'480*832': 480 * 832,
'832*480': 832 * 480,
}
SUPPORTED_SIZES = {
't2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
't2v-1.3B': ('480*832', '832*480'),
'i2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
'flf2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
't2i-14B': tuple(SIZE_CONFIGS.keys()),
'vace-1.3B': ('480*832', '832*480'),
'vace-14B': ('720*1280', '1280*720', '480*832', '832*480')
}
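The size tables are keyed by `Width*Height` strings, and not every task accepts every size: the 1.3B models list only the 480P shapes. `MAX_AREA_CONFIGS` stores the same shapes as pixel counts, e.g. 832 * 480 = 399,360. Below is a minimal sketch of validating a task/size pair before building a pipeline. It copies a subset of the dictionaries above so it runs on its own, and `check_task_size` is a hypothetical name rather than a function in `wan.configs`.

```python
# Subsets copied from wan/configs/__init__.py so the sketch is standalone.
SUPPORTED_SIZES = {
    't2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
    't2v-1.3B': ('480*832', '832*480'),
    'vace-1.3B': ('480*832', '832*480'),
}
MAX_AREA_CONFIGS = {
    '720*1280': 720 * 1280,
    '1280*720': 1280 * 720,
    '480*832': 480 * 832,
    '832*480': 832 * 480,
}


def check_task_size(task: str, size: str) -> None:
    """Raise if `size` is not listed for `task` (hypothetical helper)."""
    if task not in SUPPORTED_SIZES:
        raise KeyError(f"unknown task {task!r}; expected one of {sorted(SUPPORTED_SIZES)}")
    if size not in SUPPORTED_SIZES[task]:
        raise ValueError(
            f"unsupported size {size!r} for task {task!r}; "
            f"supported: {', '.join(SUPPORTED_SIZES[task])}")


if __name__ == "__main__":
    check_task_size('t2v-1.3B', '832*480')  # passes silently
    print(MAX_AREA_CONFIGS['832*480'])      # 399360
```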
================================================
FILE: wan/configs/shared_config.py
================================================
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
from easydict import EasyDict
#------------------------ Wan shared config ------------------------#
wan_shared_cfg = EasyDict()
# t5
wan_shared_cfg.t5_model = 'umt5_xxl'
wan_shared_cfg.t5_dtype = torch.bfloat16
wan_shared_cfg.text_len = 512
# transformer
wan_shared_cfg.param_dtype = torch.bfloat16
# inference
wan_shared_cfg.num_train_timesteps = 1000
wan_shared_cfg.sample_fps = 16
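# English gloss of the default negative prompt below (the Chinese string itself is left as-is):
# garish colors, overexposed, static, blurry details, subtitles, style, artwork,
# painting, frame, still, overall grayish, worst quality, low quality,
# JPEG compression artifacts, ugly, mutilated, extra fingers, poorly drawn hands,
# poorly drawn face, deformed, disfigured, malformed limbs, fused fingers,
# motionless frames, cluttered background, three legs, crowded background,
# walking backwards
wan_shared_cfg.sample_neg_prompt = '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走'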
================================================
FILE: wan/configs/wan_i2v_14B.py
================================================
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
from easydict import EasyDict
from .shared_config import wan_shared_cfg
#------------------------ Wan I2V 14B ------------------------#
i2v_14B = EasyDict(__name__='Config: Wan I2V 14B')
i2v_14B.update(wan_shared_cfg)
# "镜头晃动" means "camera shake"; it is prepended to the shared negative prompt.
i2v_14B.sample_neg_prompt = "镜头晃动," + i2v_14B.sample_neg_prompt
# t5
i2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
i2v_14B.t5_tokenizer = 'google/umt5-xxl'
# clip
i2v_14B.clip_model = 'clip_xlm_roberta_vit_h_14'
i2v_14B.clip_dtype = torch.float16
i2v_14B.clip_checkpoint = 'models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth'
i2v_14B.clip_tokenizer = 'xlm-roberta-large'
# vae
i2v_14B.vae_checkpoint = 'Wan2.1_VAE.pth'
i2v_14B.vae_stride = (4, 8, 8)
# transformer
i2v_14B.patch_size = (1, 2, 2)
i2v_14B.dim = 5120
i2v_14B.ffn_dim = 13824
i2v_14B.freq_dim = 256
i2v_14B.num_heads = 40
i2v_14B.num_layers = 40
i2v_14B.window_size = (-1, -1)
i2v_14B.qk_norm = True
i2v_14B.cross_attn_norm = True
i2v_14B.eps = 1e-6
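Each model config starts from `wan_shared_cfg`, overrides or prefixes some fields, and `wan/configs/__init__.py` then derives `flf2v_14B` from `i2v_14B` with `copy.deepcopy` before prepending another negative-prompt term. Below is a minimal sketch of that layering. Plain dicts stand in for `EasyDict` (an assumption, so it runs without the `easydict` package), and placeholder strings replace the real Chinese prompts. Using `deepcopy` means that even mutable values nested in the config would not be shared between the two configs.

```python
import copy

# Stand-in for wan_shared_cfg (placeholder values, plain dict instead of EasyDict).
shared_cfg = {
    "text_len": 512,
    "sample_fps": 16,
    "sample_neg_prompt": "<shared negative prompt>",
}

# Model config: start from a name, pull in shared values, then prefix a term.
i2v_cfg = {"__name__": "Config: Wan I2V 14B"}
i2v_cfg.update(shared_cfg)
i2v_cfg["sample_neg_prompt"] = "<camera shake>," + i2v_cfg["sample_neg_prompt"]

# Derived config: copy first so later edits never leak back into i2v_cfg.
flf2v_cfg = copy.deepcopy(i2v_cfg)
flf2v_cfg["__name__"] = "Config: Wan FLF2V 14B"
flf2v_cfg["sample_neg_prompt"] = "<shot cut>," + flf2v_cfg["sample_neg_prompt"]

if __name__ == "__main__":
    for cfg in (shared_cfg, i2v_cfg, flf2v_cfg):
        print(cfg.get("__name__", "shared"), "->", cfg["sample_neg_prompt"])
```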
================================================
FILE: wan/configs/wan_t2v_14B.py
================================================
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from easydict import EasyDict
from .shared_config import wan_shared_cfg
#------------------------ Wan T2V 14B ------------------------#
t2v_14B = EasyDict(__name__='Config: Wan T2V 14B')
t2v_14B.update(wan_shared_cfg)
# t5
t2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
t2v_14B.t5_tokenizer = 'google/umt5-xxl'
# vae
t2v_14B.vae_checkpoint = 'Wan2.1_VAE.pth'
t2v_14B.vae_stride = (4, 8, 8)
# transformer
t2v_14B.patch_size = (1, 2, 2)
t2v_14B.dim = 5120
t2v_14B.ffn_dim = 13824
t2v_14B.freq_dim = 256
t2v_14B.num_heads = 40
t2v_14B.num_layers = 40
t2v_14B.window_size = (-1, -1)
t2v_14B.qk_norm = True
t2v_14B.cross_attn_norm = True
t2v_14B.eps = 1e-6
================================================
FILE: wan/configs/wan_t2v_1_3B.py
================================================
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from easydict import EasyDict
from .shared_config import wan_shared_cfg
#------------------------ Wan T2V 1.3B ------------------------#
t2v_1_3B = EasyDict(__name__='Config: Wan T2V 1.3B')
t2v_1_3B.update(wan_shared_cfg)
# t5
t2v_1_3B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
t2v_1_3B.t5_tokenizer = 'google/umt5-xxl'
# vae
t2v_1_3B.vae_checkpoint = 'Wan2.1_VAE.pth'
t2v_1_3B.vae_stride = (4, 8, 8)
# transformer
t2v_1_3B.patch_size = (1, 2, 2)
t2v_1_3B.dim = 1536
t2v_1_3B.ffn_dim = 8960
t2v_1_3B.freq_dim = 256
t2v_1_3B.num_heads = 12
t2v_1_3B.num_layers = 30
t2v_1_3B.window_size = (-1, -1)
t2v_1_3B.qk_norm = True
t2v_1_3B.cross_attn_norm = True
t2v_1_3B.eps = 1e-6
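The config fields above fix how many tokens the DiT processes per video. Below is a minimal sketch of that arithmetic. It assumes the VAE keeps the first frame and compresses the remaining frames by `vae_stride[0]` in time, hence `(frame_num - 1) // 4 + 1`. It also assumes the result is rounded up to a multiple of the sequence-parallel size, so that the equal `torch.chunk` split in `usp_dit_forward` gives every rank the same length. This reflects our reading of the pipeline, not a verified copy of it. The sketch also checks that both model sizes use a per-head dimension of `dim // num_heads = 128`.

```python
def latent_seq_len(width, height, frame_num=81, vae_stride=(4, 8, 8),
                   patch_size=(1, 2, 2), sp_size=1):
    """Number of DiT tokens for one video (sketch under the stated assumptions)."""
    # Latent grid: first frame kept, then one latent frame per vae_stride[0] frames.
    lat_t = (frame_num - 1) // vae_stride[0] + 1
    lat_h = height // vae_stride[1]
    lat_w = width // vae_stride[2]
    # Each patch_size block of the latent grid becomes one token.
    tokens = (lat_t // patch_size[0]) * (lat_h // patch_size[1]) * (lat_w // patch_size[2])
    # Integer ceiling to a multiple of sp_size so every rank gets an equal chunk.
    return -(-tokens // sp_size) * sp_size


def head_dim(dim, num_heads):
    """Per-head channel count of the attention layers."""
    return dim // num_heads


if __name__ == "__main__":
    print(latent_seq_len(832, 480))    # 480P, 81 frames
    print(latent_seq_len(1280, 720))   # 720P, 81 frames
    print(head_dim(1536, 12), head_dim(5120, 40))
```

Because the token count grows with the product of the latent grid sides, moving from 480P to 720P more than doubles the sequence length for the same 81 frames.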
================================================
FILE: wan/distributed/__init__.py
================================================
================================================
FILE: wan/distributed/fsdp.py
================================================
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import gc
from functools import partial
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
from torch.distributed.utils import _free_storage
def shard_model(
    model,
    device_id,
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
    buffer_dtype=torch.float32,
    process_group=None,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    sync_module_states=True,
):
    model = FSDP(
        module=model,
        process_group=process_group,
        sharding_strategy=sharding_strategy,
        auto_wrap_policy=partial(
            lambda_auto_wrap_policy, lambda_fn=lambda m: m in model.blocks),
        mixed_precision=MixedPrecision(
            param_dtype=param_dtype,
            reduce_dtype=reduce_dtype,
            buffer_dtype=buffer_dtype),
        device_id=device_id,
        sync_module_states=sync_module_states)
    return model
def free_model(model):
    for m in model.modules():
        if isinstance(m, FSDP):
            _free_storage(m._handle.flat_param.data)
    del model
    gc.collect()
    torch.cuda.empty_cache()
================================================
FILE: wan/distributed/xdit_context_parallel.py
================================================
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
import torch.cuda.amp as amp
from xfuser.core.distributed import (
get_sequence_parallel_rank,
get_sequence_parallel_world_size,
get_sp_group,
)
from xfuser.core.long_ctx_attention import xFuserLongContextAttention
from ..modules.model import sinusoidal_embedding_1d
def pad_freqs(original_tensor, target_len):
    seq_len, s1, s2 = original_tensor.shape
    pad_size = target_len - seq_len
    padding_tensor = torch.ones(
        pad_size,
        s1,
        s2,
        dtype=original_tensor.dtype,
        device=original_tensor.device)
    padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
    return padded_tensor
@amp.autocast(enabled=False)
def rope_apply(x, grid_sizes, freqs):
    """
    x: [B, L, N, C].
    grid_sizes: [B, 3].
    freqs: [M, C // 2].
    """
    s, n, c = x.size(1), x.size(2), x.size(3) // 2
    # split freqs
    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
    # loop over samples
    output = []
    for i, (f, h, w) in enumerate(grid_sizes.tolist()):
        seq_len = f * h * w
        # precompute multipliers
        x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(
            s, n, -1, 2))
        freqs_i = torch.cat([
            freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
            freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
            freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
        ], dim=-1).reshape(seq_len, 1, -1)
        # apply rotary embedding
        sp_size = get_sequence_parallel_world_size()
        sp_rank = get_sequence_parallel_rank()
        freqs_i = pad_freqs(freqs_i, s * sp_size)
        s_per_rank = s
        freqs_i_rank = freqs_i[(sp_rank * s_per_rank):((sp_rank + 1) *
                                                       s_per_rank), :, :]
        x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2)
        x_i = torch.cat([x_i, x[i, s:]])
        # append to collection
        output.append(x_i)
    return torch.stack(output).float()
def usp_dit_forward_vace(self, x, vace_context, seq_len, kwargs):
    # embeddings
    c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
    c = [u.flatten(2).transpose(1, 2) for u in c]
    c = torch.cat([
        torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1)
        for u in c
    ])
    # arguments
    new_kwargs = dict(x=x)
    new_kwargs.update(kwargs)
    # Context Parallel
    c = torch.chunk(
        c, get_sequence_parallel_world_size(),
        dim=1)[get_sequence_parallel_rank()]
    hints = []
    for block in self.vace_blocks:
        c, c_skip = block(c, **new_kwargs)
        hints.append(c_skip)
    return hints
def usp_dit_forward(
    self,
    x,
    t,
    context,
    seq_len,
    vace_context=None,
    vace_context_scale=1.0,
    clip_fea=None,
    y=None,
):
    """
    x: A list of videos each with shape [C, T, H, W].
    t: [B].
    context: A list of text embeddings each with shape [L, C].
    """
    if self.model_type == 'i2v':
        assert clip_fea is not None and y is not None
    # params
    device = self.patch_embedding.weight.device
    if self.freqs.device != device:
        self.freqs = self.freqs.to(device)
    if self.model_type != 'vace' and y is not None:
        x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
    # embeddings
    x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
    grid_sizes = torch.stack(
        [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
    x = [u.flatten(2).transpose(1, 2) for u in x]
    seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
    assert seq_lens.max() <= seq_len
    x = torch.cat([
        torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1)
        for u in x
    ])
    # time embeddings
    with amp.autocast(dtype=torch.float32):
        e = self.time_embedding(
            sinusoidal_embedding_1d(self.freq_dim, t).float())
        e0 = self.time_projection(e).unflatten(1, (6, self.dim))
        assert e.dtype == torch.float32 and e0.dtype == torch.float32
    # context
    context_lens = None
    context = self.text_embedding(
        torch.stack([
            torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
            for u in context
        ]))
    if self.model_type != 'vace' and clip_fea is not None:
        context_clip = self.img_emb(clip_fea)  # bs x 257 x dim
        context = torch.concat([context_clip, context], dim=1)
    # arguments
    kwargs = dict(
        e=e0,
        seq_lens=seq_lens,
        grid_sizes=grid_sizes,
        freqs=self.freqs,
        context=context,
        context_lens=context_lens)
    # Context Parallel
    x = torch.chunk(
        x, get_sequence_parallel_world_size(),
        dim=1)[get_sequence_parallel_rank()]
    if self.model_type == 'vace':
        hints = self.forward_vace(x, vace_context, seq_len, kwargs)
        kwargs['hints'] = hints
        kwargs['context_scale'] = vace_context_scale
    for block in self.blocks:
        x = block(x, **kwargs)
    # head
    x = self.head(x, e)
    # Context Parallel
    x = get_sp_group().all_gather(x, dim=1)
    # unpatchify
    x = self.unpatchify(x, grid_sizes)
    return [u.float() for u in x]
def usp_attn_forward(self,
                     x,
                     seq_lens,
                     grid_sizes,
                     freqs,
                     dtype=torch.bfloat16):
    b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
    half_dtypes = (torch.float16, torch.bfloat16)
    def half(x):
        return x if x.dtype in half_dtypes else x.to(dtype)
    # query, key, value function
    def qkv_fn(x):
        q = self.norm_q(self.q(x)).view(b, s, n, d)
        k = self.norm_k(self.k(x)).view(b, s, n, d)
        v = self.v(x).view(b, s, n, d)
        return q, k, v
    q, k, v = qkv_fn(x)
    q = rope_apply(q, grid_sizes, freqs)
    k = rope_apply(k, grid_sizes, freqs)
    # TODO: We should use unpadded q,k,v for attention.
    # k_lens = seq_lens // get_sequence_parallel_world_size()
    # if k_lens is not None:
    #     q = torch.cat([u[:l] for u, l in zip(q, k_lens)]).unsqueeze(0)
    #     k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0)
    #     v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0)
    x = xFuserLongContextAttention()(
        None,
        query=half(q),
        key=half(k),
        value=half(v),
        window_size=self.window_size)
    # TODO: padding after attention.
    # x = torch.cat([x, x.new_zeros(b, s - x.size(1), n, d)], dim=1)
    # output
    x = x.flatten(2)
    x = self.o(x)
    return x
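`usp_dit_forward` splits the padded token sequence into `sp_size` equal chunks along the sequence dimension. `rope_apply` then gives each rank the matching slice of the rotary frequencies, padding the table with ones so its length is `s * sp_size`. A complex 1 is the identity rotation, so padded positions are left untouched. Below is a minimal NumPy sketch of that slicing. It checks that rotating each rank's chunk with its own slice and concatenating gives the same result as rotating the whole padded sequence at once. Several simplifications are assumptions: a 1-D position index stands in for the real (frame, height, width) grid, the sizes are toy values, and `theta=10000` is the common RoPE default rather than a value read from this repository. `split_sizes` reproduces the channel split that `rope_apply` uses for the time, height and width axes.

```python
import numpy as np


def split_sizes(c):
    """Channel split used by rope_apply for the (time, height, width) axes."""
    return c - 2 * (c // 3), c // 3, c // 3


def rope_freqs(num_pos, half_dim, theta=10000.0):
    """Unit complex rotations exp(i * pos * w_k), one row per position."""
    inv_freq = 1.0 / theta ** (np.arange(half_dim) / half_dim)
    return np.exp(1j * np.outer(np.arange(num_pos), inv_freq))


def pad_freqs(freqs, target_len):
    # Pad with ones: multiplying by 1+0j leaves padded tokens unrotated.
    pad = np.ones((target_len - freqs.shape[0], freqs.shape[1]), dtype=freqs.dtype)
    return np.concatenate([freqs, pad], axis=0)


def apply_rope(x, freqs):
    # x: [L, 2 * half_dim] real; consecutive pairs (x0, x1) form x0 + i*x1.
    out = (x[:, 0::2] + 1j * x[:, 1::2]) * freqs
    return np.stack([out.real, out.imag], axis=-1).reshape(x.shape)


def sharded_rope(x, freqs, sp_size):
    """Rotate rank by rank: equal chunks of x, matching slices of padded freqs."""
    s = x.shape[0] // sp_size               # tokens per rank (x is pre-padded)
    padded = pad_freqs(freqs, s * sp_size)
    parts = [apply_rope(chunk, padded[rank * s:(rank + 1) * s])
             for rank, chunk in enumerate(np.split(x, sp_size, axis=0))]
    return np.concatenate(parts, axis=0)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    sp_size, seq_len, half_dim, valid = 4, 16, 8, 13
    x = np.zeros((seq_len, 2 * half_dim))
    x[:valid] = rng.standard_normal((valid, 2 * half_dim))
    freqs = rope_freqs(valid, half_dim)      # frequencies only for real tokens
    full = apply_rope(x, pad_freqs(freqs, seq_len))
    print(np.allclose(sharded_rope(x, freqs, sp_size), full))
    print(split_sizes(64))  # head_dim 128 -> 64 complex channels
```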
================================================
FILE: wan/first_last_frame2video.py
================================================
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import gc
import logging
import math
import os
import random
import sys
import types
from contextlib import contextmanager
from functools import partial
import numpy as np
import torch
import torch.cuda.amp as amp
import torch.distributed as dist
import torchvision.transforms.functional as TF
from tqdm import tqdm
from .distributed.fsdp import shard_model
from .modules.clip import CLIPModel
from .modules.model import WanModel
from .modules.t5 import T5EncoderModel
from .modules.vae import WanVAE
from .utils.fm_solvers import (
FlowDPMSolverMultistepScheduler,
get_sampling_sigmas,
retrieve_timesteps,
)
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
class WanFLF2V:
    def __init__(
        self,
        config,
        checkpoint_dir,
        device_id=0,
        rank=0,
        t5_fsdp=False,
        dit_fsdp=False,
        use_usp=False,
        t5_cpu=False,
        init_on_cpu=True,
    ):
        r"""
        Initializes the first-last-frame-to-video (FLF2V) generation model components.
        Args:
            config (EasyDict):
                Object containing model parameters initialized from config.py
            checkpoint_dir (`str`):
                Path to directory containing model checkpoints
            device_id (`int`, *optional*, defaults to 0):
                Id of target GPU device
            rank (`int`, *optional*, defaults to 0):
                Process rank for distributed training
            t5_fsdp (`bool`, *optional*, defaults to False):
                Enable FSDP sharding for T5 model
            dit_fsdp (`bool`, *optional*, defaults to False):
                Enable FSDP sharding for DiT model
            use_usp (`bool`, *optional*, defaults to False):
                Enable distribution strategy of USP.
            t5_cpu (`bool`, *optional*, defaults to False):
                Whether to place T5 model on CPU. Only works without t5_fsdp.
            init_on_cpu (`bool`, *optional*, defaults to True):
                Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
        """
        self.device = torch.device(f"cuda:{device_id}")
        self.config = config
        self.rank = rank
        self.use_usp = use_usp
        self.t5_cpu = t5_cpu
        self.num_train_timesteps = config.num_train_timesteps
        self.param_dtype = config.param_dtype
        shard_fn = partial(shard_model, device_id=device_id)
        self.text_encoder = T5EncoderModel(
            text_len=config.text_len,
            dtype=config.t5_dtype,
            device=torch.device('cpu'),
            checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
            tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
            shard_fn=shard_fn if t5_fsdp else None,
        )
        self.vae_stride = config.vae_stride
        self.patch_size = config.patch_size
        self.vae = WanVAE(
            vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
            device=self.device)
        self.clip = CLIPModel(
            dtype=config.clip_dtype,
            device=self.device,
            checkpoint_path=os.path.join(checkpoint_dir,
                                         config.clip_checkpoint),
            tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer))
        logging.info(f"Creating WanModel from {checkpoint_dir}")
        self.model = WanModel.from_pretrained(checkpoint_dir)
        self.model.eval().requires_grad_(False)
        if t5_fsdp or dit_fsdp or use_usp:
            init_on_cpu = False
        if use_usp:
            from xfuser.core.distributed import get_sequence_parallel_world_size
            from .distributed.xdit_context_parallel import (
                usp_attn_forward,
                usp_dit_forward,
            )
            for block in self.model.blocks:
                block.self_attn.forward = types.MethodType(
                    usp_attn_forward, block.self_attn)
            self.model.forward = types.MethodType(usp_dit_forward, self.model)
            self.sp_size = get_sequence_parallel_world_size()
        else:
            self.sp_size = 1
        if dist.is_initialized():
            dist.barrier()
        if dit_fsdp:
            self.model = shard_fn(self.model)
        else:
            if not init_on_cpu:
                self.model.to(self.device)
        self.sample_neg_prompt = config.sample_neg_prompt
def generate(self,
input_prompt,
first_frame,
last_frame,
max_area=720 * 1280,
frame_num=81,
shift=16,
sample_solver='unipc',
sampling_steps=50,
guide_scale=5.5,
n_prompt="",
seed=-1,
offload_model=True):
r"""
Generates a video from the given first and last frames and a text prompt using the diffusion process.
Args:
input_prompt (`str`):
Text prompt for content generation.
first_frame (PIL.Image.Image):
First frame of the video. Converted internally to a [3, H, W] tensor in [-1, 1].
last_frame (PIL.Image.Image):
Last frame of the video. Converted internally to a [3, H, W] tensor in [-1, 1].
[NOTE] If first_frame and last_frame differ in size, last_frame is resized and center-cropped
to match first_frame.
max_area (`int`, *optional*, defaults to 720*1280):
Maximum pixel area for latent space calculation. Controls video resolution scaling
frame_num (`int`, *optional*, defaults to 81):
Number of frames to generate. Must be of the form 4n+1.
shift (`float`, *optional*, defaults to 16):
Noise schedule shift parameter. Affects temporal dynamics
[NOTE]: If you want to generate a 480p video, it is recommended to set the shift value to 3.0.
sample_solver (`str`, *optional*, defaults to 'unipc'):
Solver used to sample the video.
sampling_steps (`int`, *optional*, defaults to 50):
Number of diffusion sampling steps. Higher values improve quality but slow generation
guide_scale (`float`, *optional*, defaults to 5.5):
Classifier-free guidance scale. Controls prompt adherence vs. creativity
n_prompt (`str`, *optional*, defaults to ""):
Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
seed (`int`, *optional*, defaults to -1):
Random seed for noise generation. If -1, use random seed
offload_model (`bool`, *optional*, defaults to True):
If True, offloads models to CPU during generation to save VRAM
Returns:
torch.Tensor:
Generated video frames tensor. Dimensions: (C, N, H, W) where:
- C: Color channels (3 for RGB)
- N: Number of frames (frame_num)
- H: Frame height (from max_area)
- W: Frame width (from max_area)
Only rank 0 returns the tensor; other ranks return None.
"""
first_frame_size = first_frame.size
last_frame_size = last_frame.size
first_frame = TF.to_tensor(first_frame).sub_(0.5).div_(0.5).to(
self.device)
last_frame = TF.to_tensor(last_frame).sub_(0.5).div_(0.5).to(
self.device)
F = frame_num
first_frame_h, first_frame_w = first_frame.shape[1:]
aspect_ratio = first_frame_h / first_frame_w
lat_h = round(
np.sqrt(max_area * aspect_ratio) // self.vae_stride[1] //
self.patch_size[1] * self.patch_size[1])
lat_w = round(
np.sqrt(max_area / aspect_ratio) // self.vae_stride[2] //
self.patch_size[2] * self.patch_size[2])
first_frame_h = lat_h * self.vae_stride[1]
first_frame_w = lat_w * self.vae_stride[2]
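if first_frame_size != last_frame_size:
# PIL sizes are (W, H), while torchvision expects [H, W].
# 1. resize so that last_frame fully covers first_frame
last_frame_resize_ratio = max(
first_frame_size[0] / last_frame_size[0],
first_frame_size[1] / last_frame_size[1])
last_frame_size = [
round(last_frame_size[1] * last_frame_resize_ratio),
round(last_frame_size[0] * last_frame_resize_ratio),
]
last_frame = TF.resize(last_frame, last_frame_size)
# 2. center crop to first_frame's size
last_frame = TF.center_crop(
last_frame, [first_frame_size[1], first_frame_size[0]])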
max_seq_len = ((F - 1) // self.vae_stride[0] + 1) * lat_h * lat_w // (
self.patch_size[1] * self.patch_size[2])
max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size
seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
seed_g = torch.Generator(device=self.device)
seed_g.manual_seed(seed)
noise = torch.randn(
16, (F - 1) // 4 + 1,
lat_h,
lat_w,
dtype=torch.float32,
generator=seed_g,
device=self.device)
msk = torch.ones(1, F, lat_h, lat_w, device=self.device)
msk[:, 1:-1] = 0
msk = torch.concat([
torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]
],
dim=1)
msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
msk = msk.transpose(1, 2)[0]
if n_prompt == "":
n_prompt = self.sample_neg_prompt
# preprocess
if not self.t5_cpu:
self.text_encoder.model.to(self.device)
context = self.text_encoder([input_prompt], self.device)
context_null = self.text_encoder([n_prompt], self.device)
if offload_model:
self.text_encoder.model.cpu()
else:
context = self.text_encoder([input_prompt], torch.device('cpu'))
context_null = self.text_encoder([n_prompt], torch.device('cpu'))
context = [t.to(self.device) for t in context]
context_null = [t.to(self.device) for t in context_null]
self.clip.model.to(self.device)
clip_context = self.clip.visual(
[first_frame[:, None, :, :], last_frame[:, None, :, :]])
if offload_model:
self.clip.model.cpu()
y = self.vae.encode([
torch.concat([
torch.nn.functional.interpolate(
first_frame[None].cpu(),
size=(first_frame_h, first_frame_w),
mode='bicubic').transpose(0, 1),
torch.zeros(3, F - 2, first_frame_h, first_frame_w),
torch.nn.functional.interpolate(
last_frame[None].cpu(),
size=(first_frame_h, first_frame_w),
mode='bicubic').transpose(0, 1),
],
dim=1).to(self.device)
])[0]
y = torch.concat([msk, y])
@contextmanager
def noop_no_sync():
yield
no_sync = getattr(self.model, 'no_sync', noop_no_sync)
# evaluation mode
with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync():
if sample_solver == 'unipc':
sample_scheduler = FlowUniPCMultistepScheduler(
num_train_timesteps=self.num_train_timesteps,
shift=1,
use_dynamic_shifting=False)
sample_scheduler.set_timesteps(
sampling_steps, device=self.device, shift=shift)
timesteps = sample_scheduler.timesteps
elif sample_solver == 'dpm++':
sample_scheduler = FlowDPMSolverMultistepScheduler(
num_train_timesteps=self.num_train_timesteps,
shift=1,
use_dynamic_shifting=False)
sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
timesteps, _ = retrieve_timesteps(
sample_scheduler,
device=self.device,
sigmas=sampling_sigmas)
else:
raise NotImplementedError("Unsupported solver.")
# sample videos
latent = noise
arg_c = {
'context': [context[0]],
'clip_fea': clip_context,
'seq_len': max_seq_len,
'y': [y],
}
arg_null = {
'context': context_null,
'clip_fea': clip_context,
'seq_len': max_seq_len,
'y': [y],
}
if offload_model:
torch.cuda.empty_cache()
self.model.to(self.device)
for _, t in enumerate(tqdm(timesteps)):
latent_model_input = [latent.to(self.device)]
timestep = [t]
timestep = torch.stack(timestep).to(self.device)
noise_pred_cond = self.model(
latent_model_input, t=timestep, **arg_c)[0].to(
torch.device('cpu') if offload_model else self.device)
if offload_model:
torch.cuda.empty_cache()
noise_pred_uncond = self.model(
latent_model_input, t=timestep, **arg_null)[0].to(
torch.device('cpu') if offload_model else self.device)
if offload_model:
torch.cuda.empty_cache()
noise_pred = noise_pred_uncond + guide_scale * (
noise_pred_cond - noise_pred_uncond)
latent = latent.to(
torch.device('cpu') if offload_model else self.device)
temp_x0 = sample_scheduler.step(
noise_pred.unsqueeze(0),
t,
latent.unsqueeze(0),
return_dict=False,
generator=seed_g)[0]
latent = temp_x0.squeeze(0)
x0 = [latent.to(self.device)]
del latent_model_input, timestep
if offload_model:
self.model.cpu()
torch.cuda.empty_cache()
if self.rank == 0:
videos = self.vae.decode(x0)
del noise, latent
del sample_scheduler
if offload_model:
gc.collect()
torch.cuda.synchronize()
if dist.is_initialized():
dist.barrier()
return videos[0] if self.rank == 0 else None
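The conditioning mask built in `generate()` marks which pixel frames are given (the first and last frames here; only the first frame in `wan/image2video.py`, which uses `msk[:, 1:] = 0`) and folds them into one group of four entries per latent frame, with the first frame repeated four times. The following pure-Python sketch mirrors that bookkeeping without torch. `conditioning_mask` is a hypothetical helper, and it returns a `[latent_frames][4]` list, i.e. the transpose of the `[4, latent_frames]` layout the torch code produces (before the spatial dimensions).

```python
def conditioning_mask(frame_num, keep_last):
    """Hypothetical pure-Python mirror of the msk construction in generate().

    Returns one 4-entry group per latent frame; 1 marks a conditioned pixel frame.
    """
    assert frame_num >= 5 and (frame_num - 1) % 4 == 0, "frame_num must be 4n + 1"
    # Per-pixel-frame mask: the first frame is given, optionally the last one too.
    mask = [1] + [0] * (frame_num - 2) + [1 if keep_last else 0]
    # Repeat the first frame 4 times so the length becomes a multiple of 4,
    # matching the 4x temporal compression applied after frame 0.
    mask = [mask[0]] * 4 + mask[1:]
    # Group consecutive entries: group t corresponds to latent frame t.
    return [mask[i:i + 4] for i in range(0, len(mask), 4)]


flf2v = conditioning_mask(81, keep_last=True)
i2v = conditioning_mask(81, keep_last=False)
print(len(flf2v), flf2v[0], flf2v[1], flf2v[-1])  # 21 [1, 1, 1, 1] [0, 0, 0, 0] [0, 0, 0, 1]
print(i2v[-1])  # [0, 0, 0, 0]
```

With 81 frames the mask has 21 groups, one per latent frame, and only the first group (and, for first-last-frame generation, the last entry of the final group) is set.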
================================================
FILE: wan/image2video.py
================================================
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import gc
import logging
import math
import os
import random
import sys
import types
from contextlib import contextmanager
from functools import partial
import numpy as np
import torch
import torch.cuda.amp as amp
import torch.distributed as dist
import torchvision.transforms.functional as TF
from tqdm import tqdm
from .distributed.fsdp import shard_model
from .modules.clip import CLIPModel
from .modules.model import WanModel
from .modules.t5 import T5EncoderModel
from .modules.vae import WanVAE
from .utils.fm_solvers import (
FlowDPMSolverMultistepScheduler,
get_sampling_sigmas,
retrieve_timesteps,
)
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
class WanI2V:
def __init__(
self,
config,
checkpoint_dir,
device_id=0,
rank=0,
t5_fsdp=False,
dit_fsdp=False,
use_usp=False,
t5_cpu=False,
init_on_cpu=True,
):
r"""
Initializes the image-to-video generation model components.
Args:
config (EasyDict):
Object containing model parameters initialized from config.py
checkpoint_dir (`str`):
Path to directory containing model checkpoints
device_id (`int`, *optional*, defaults to 0):
Id of target GPU device
rank (`int`, *optional*, defaults to 0):
Process rank for distributed training
t5_fsdp (`bool`, *optional*, defaults to False):
Enable FSDP sharding for T5 model
dit_fsdp (`bool`, *optional*, defaults to False):
Enable FSDP sharding for DiT model
use_usp (`bool`, *optional*, defaults to False):
Enable distribution strategy of USP.
t5_cpu (`bool`, *optional*, defaults to False):
Whether to place T5 model on CPU. Only works without t5_fsdp.
init_on_cpu (`bool`, *optional*, defaults to True):
Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
"""
self.device = torch.device(f"cuda:{device_id}")
self.config = config
self.rank = rank
self.use_usp = use_usp
self.t5_cpu = t5_cpu
self.num_train_timesteps = config.num_train_timesteps
self.param_dtype = config.param_dtype
shard_fn = partial(shard_model, device_id=device_id)
self.text_encoder = T5EncoderModel(
text_len=config.text_len,
dtype=config.t5_dtype,
device=torch.device('cpu'),
checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
shard_fn=shard_fn if t5_fsdp else None,
)
self.vae_stride = config.vae_stride
self.patch_size = config.patch_size
self.vae = WanVAE(
vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
device=self.device)
self.clip = CLIPModel(
dtype=config.clip_dtype,
device=self.device,
checkpoint_path=os.path.join(checkpoint_dir,
config.clip_checkpoint),
tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer))
logging.info(f"Creating WanModel from {checkpoint_dir}")
self.model = WanModel.from_pretrained(checkpoint_dir)
self.model.eval().requires_grad_(False)
if t5_fsdp or dit_fsdp or use_usp:
init_on_cpu = False
if use_usp:
from xfuser.core.distributed import get_sequence_parallel_world_size
from .distributed.xdit_context_parallel import (
usp_attn_forward,
usp_dit_forward,
)
for block in self.model.blocks:
block.self_attn.forward = types.MethodType(
usp_attn_forward, block.self_attn)
self.model.forward = types.MethodType(usp_dit_forward, self.model)
self.sp_size = get_sequence_parallel_world_size()
else:
self.sp_size = 1
if dist.is_initialized():
dist.barrier()
if dit_fsdp:
self.model = shard_fn(self.model)
else:
if not init_on_cpu:
self.model.to(self.device)
self.sample_neg_prompt = config.sample_neg_prompt
def generate(self,
input_prompt,
img,
max_area=720 * 1280,
frame_num=81,
shift=5.0,
sample_solver='unipc',
sampling_steps=40,
guide_scale=5.0,
n_prompt="",
seed=-1,
offload_model=True):
r"""
Generates video frames from an input image and a text prompt using the diffusion process.
Args:
input_prompt (`str`):
Text prompt for content generation.
img (PIL.Image.Image):
Input image. Converted internally to a [3, H, W] tensor in [-1, 1].
max_area (`int`, *optional*, defaults to 720*1280):
Maximum pixel area for latent space calculation. Controls video resolution scaling
frame_num (`int`, *optional*, defaults to 81):
Number of frames to generate. Must be of the form 4n+1.
shift (`float`, *optional*, defaults to 5.0):
Noise schedule shift parameter. Affects temporal dynamics
[NOTE]: If you want to generate a 480p video, it is recommended to set the shift value to 3.0.
sample_solver (`str`, *optional*, defaults to 'unipc'):
Solver used to sample the video.
sampling_steps (`int`, *optional*, defaults to 40):
Number of diffusion sampling steps. Higher values improve quality but slow generation
guide_scale (`float`, *optional*, defaults to 5.0):
Classifier-free guidance scale. Controls prompt adherence vs. creativity
n_prompt (`str`, *optional*, defaults to ""):
Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
seed (`int`, *optional*, defaults to -1):
Random seed for noise generation. If -1, use random seed
offload_model (`bool`, *optional*, defaults to True):
If True, offloads models to CPU during generation to save VRAM
Returns:
torch.Tensor:
Generated video frames tensor. Dimensions: (C, N, H, W) where:
- C: Color channels (3 for RGB)
- N: Number of frames (frame_num)
- H: Frame height (from max_area)
- W: Frame width (from max_area)
Only rank 0 returns the tensor; other ranks return None.
"""
img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device)
F = frame_num
h, w = img.shape[1:]
aspect_ratio = h / w
lat_h = round(
np.sqrt(max_area * aspect_ratio) // self.vae_stride[1] //
self.patch_size[1] * self.patch_size[1])
lat_w = round(
np.sqrt(max_area / aspect_ratio) // self.vae_stride[2] //
self.patch_size[2] * self.patch_size[2])
h = lat_h * self.vae_stride[1]
w = lat_w * self.vae_stride[2]
max_seq_len = ((F - 1) // self.vae_stride[0] + 1) * lat_h * lat_w // (
self.patch_size[1] * self.patch_size[2])
max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size
seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
seed_g = torch.Generator(device=self.device)
seed_g.manual_seed(seed)
noise = torch.randn(
16, (F - 1) // 4 + 1,
lat_h,
lat_w,
dtype=torch.float32,
generator=seed_g,
device=self.device)
msk = torch.ones(1, F, lat_h, lat_w, device=self.device)
msk[:, 1:] = 0
msk = torch.concat([
torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]
],
dim=1)
msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
msk = msk.transpose(1, 2)[0]
if n_prompt == "":
n_prompt = self.sample_neg_prompt
# preprocess
if not self.t5_cpu:
self.text_encoder.model.to(self.device)
context = self.text_encoder([input_prompt], self.device)
context_null = self.text_encoder([n_prompt], self.device)
if offload_model:
self.text_encoder.model.cpu()
else:
context = self.text_encoder([input_prompt], torch.device('cpu'))
context_null = self.text_encoder([n_prompt], torch.device('cpu'))
context = [t.to(self.device) for t in context]
context_null = [t.to(self.device) for t in context_null]
self.clip.model.to(self.device)
clip_context = self.clip.visual([img[:, None, :, :]])
if offload_model:
self.clip.model.cpu()
y = self.vae.encode([
torch.concat([
torch.nn.functional.interpolate(
img[None].cpu(), size=(h, w), mode='bicubic').transpose(
0, 1),
torch.zeros(3, F - 1, h, w)
],
dim=1).to(self.device)
])[0]
y = torch.concat([msk, y])
@contextmanager
def noop_no_sync():
yield
no_sync = getattr(self.model, 'no_sync', noop_no_sync)
# evaluation mode
with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync():
if sample_solver == 'unipc':
sample_scheduler = FlowUniPCMultistepScheduler(
num_train_timesteps=self.num_train_timesteps,
shift=1,
use_dynamic_shifting=False)
sample_scheduler.set_timesteps(
sampling_steps, device=self.device, shift=shift)
timesteps = sample_scheduler.timesteps
elif sample_solver == 'dpm++':
sample_scheduler = FlowDPMSolverMultistepScheduler(
num_train_timesteps=self.num_train_timesteps,
shift=1,
use_dynamic_shifting=False)
sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
timesteps, _ = retrieve_timesteps(
sample_scheduler,
device=self.device,
sigmas=sampling_sigmas)
else:
raise NotImplementedError("Unsupported solver.")
# sample videos
latent = noise
arg_c = {
'context': [context[0]],
'clip_fea': clip_context,
'seq_len': max_seq_len,
'y': [y],
}
arg_null = {
'context': context_null,
'clip_fea': clip_context,
'seq_len': max_seq_len,
'y': [y],
}
if offload_model:
torch.cuda.empty_cache()
self.model.to(self.device)
for _, t in enumerate(tqdm(timesteps)):
latent_model_input = [latent.to(self.device)]
timestep = [t]
timestep = torch.stack(timestep).to(self.device)
noise_pred_cond = self.model(
latent_model_input, t=timestep, **arg_c)[0].to(
torch.device('cpu') if offload_model else self.device)
if offload_model:
torch.cuda.empty_cache()
noise_pred_uncond = self.model(
latent_model_input, t=timestep, **arg_null)[0].to(
torch.device('cpu') if offload_model else self.device)
if offload_model:
torch.cuda.empty_cache()
noise_pred = noise_pred_uncond + guide_scale * (
noise_pred_cond - noise_pred_uncond)
latent = latent.to(
torch.device('cpu') if offload_model else self.device)
temp_x0 = sample_scheduler.step(
noise_pred.unsqueeze(0),
t,
latent.unsqueeze(0),
return_dict=False,
generator=seed_g)[0]
latent = temp_x0.squeeze(0)
x0 = [latent.to(self.device)]
del latent_model_input, timestep
if offload_model:
self.model.cpu()
torch.cuda.empty_cache()
if self.rank == 0:
videos = self.vae.decode(x0)
del noise, latent
del sample_scheduler
if offload_model:
gc.collect()
torch.cuda.synchronize()
if dist.is_initialized():
dist.barrier()
return videos[0] if self.rank == 0 else None
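`generate()` derives the latent grid from `max_area` and the input aspect ratio, then computes the token count `max_seq_len` and pads it to a multiple of the sequence-parallel world size. The sketch below repeats that arithmetic in plain Python. `latent_geometry` is a hypothetical helper, and the defaults `vae_stride=(4, 8, 8)` and `patch_size=(1, 2, 2)` are assumed config values, not read from `wan/configs`.

```python
import math


def latent_geometry(img_h, img_w, max_area, frame_num,
                    vae_stride=(4, 8, 8), patch_size=(1, 2, 2), sp_size=1):
    """Hypothetical mirror of the lat_h / lat_w / max_seq_len arithmetic.

    vae_stride and patch_size defaults are assumed Wan2.1 config values.
    """
    aspect_ratio = img_h / img_w
    # Latent grid whose pixel area is about max_area at the input aspect
    # ratio; each side is rounded down to a multiple of the patch size.
    lat_h = round(math.sqrt(max_area * aspect_ratio) // vae_stride[1] //
                  patch_size[1] * patch_size[1])
    lat_w = round(math.sqrt(max_area / aspect_ratio) // vae_stride[2] //
                  patch_size[2] * patch_size[2])
    # Temporal compression: 4n + 1 pixel frames -> n + 1 latent frames.
    lat_f = (frame_num - 1) // vae_stride[0] + 1
    seq_len = lat_f * lat_h * lat_w // (patch_size[1] * patch_size[2])
    # Pad so the token sequence splits evenly across sequence-parallel ranks.
    seq_len = int(math.ceil(seq_len / sp_size)) * sp_size
    return lat_f, lat_h, lat_w, seq_len


print(latent_geometry(720, 1280, 720 * 1280, 81))              # (21, 90, 160, 75600)
print(latent_geometry(720, 1280, 720 * 1280, 81, sp_size=32))  # (21, 90, 160, 75616)
print(latent_geometry(512, 512, 720 * 1280, 81))               # (21, 120, 120, 75600)
```

Note that a square 512x512 input with the 720p area budget yields a 960x960 output (120 latent rows times the spatial stride of 8), because the area, not the height or width, is held fixed.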
================================================
FILE: wan/modules/__init__.py
================================================
from .attention import flash_attention
from .model import WanModel
from .t5 import T5Decoder, T5Encoder, T5EncoderModel, T5Model
from .tokenizers import HuggingfaceTokenizer
from .vace_model import VaceWanModel
from .vae import WanVAE
__all__ = [
'WanVAE',
'WanModel',
'VaceWanModel',
'T5Model',
'T5Encoder',
'T5Decoder',
'T5EncoderModel',
'HuggingfaceTokenizer',
'flash_attention',
]
================================================
FILE: wan/modules/attention.py
================================================
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
try:
import flash_attn_interface
FLASH_ATTN_3_AVAILABLE = True
except ModuleNotFoundError:
FLASH_ATTN_3_AVAILABLE = False
try:
import flash_attn
FLASH_ATTN_2_AVAILABLE = True
except ModuleNotFoundError:
FLASH_ATTN_2_AVAILABLE = False
import warnings
__all__ = [
'flash_attention',
'attention',
]
def flash_attention(
q,
k,
v,
q_lens=None,
k_lens=None,
dropout_p=0.,
softmax_scale=None,
q_scale=None,
causal=False,
window_size=(-1, -1),
deterministic=False,
dtype=torch.bfloat16,
version=None,
):
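"""
q: [B, Lq, Nq, C1].
k: [B, Lk, Nk, C1].
v: [B, Lk, Nk, C2]. Nq must be divisible by Nk.
q_lens: [B]. Number of valid (unpadded) tokens in each query sequence.
k_lens: [B]. Number of valid (unpadded) tokens in each key/value sequence.
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
q_scale: float. If given, q is multiplied by this factor before attention.
causal: bool. Whether to apply causal attention mask.
window_size: (left, right). If not (-1, -1), apply sliding window local attention.
deterministic: bool. If True, slightly slower and uses more memory.
dtype: torch.dtype. Cast target used when q/k/v are not float16/bfloat16.
version: int or None. FlashAttention version (2 or 3); None prefers 3 when available.
"""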
half_dtypes = (torch.float16, torch.bfloat16)
assert dtype in half_dtypes
assert q.device.type == 'cuda' and q.size(-1) <= 256
# params
b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype
def half(x):
return x if x.dtype in half_dtypes else x.to(dtype)
# preprocess query
if q_lens is None:
q = half(q.flatten(0, 1))
q_lens = torch.tensor(
[lq] * b, dtype=torch.int32).to(
device=q.device, non_blocking=True)
else:
q = half(torch.cat([u[:n] for u, n in zip(q, q_lens)]))
# preprocess key, value
if k_lens is None:
k = half(k.flatten(0, 1))
v = half(v.flatten(0, 1))
k_lens = torch.tensor(
[lk] * b, dtype=torch.int32).to(
device=k.device, non_blocking=True)
else:
k = half(torch.cat([u[:n] for u, n in zip(k, k_lens)]))
v = half(torch.cat([u[:n] for u, n in zip(v, k_lens)]))
q = q.to(v.dtype)
k = k.to(v.dtype)
if q_scale is not None:
q = q * q_scale
if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
warnings.warn(
'FlashAttention 3 is not available; falling back to FlashAttention 2.'
)
# apply attention
if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE:
# Note: dropout_p, window_size are not supported in FA3 now.
x = flash_attn_interface.flash_attn_varlen_func(
q=q,
k=k,
v=v,
cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
0, dtype=torch.int32).to(q.device, non_blocking=True),
cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
0, dtype=torch.int32).to(q.device, non_blocking=True),
seqused_q=None,
seqused_k=None,
max_seqlen_q=lq,
max_seqlen_k=lk,
softmax_scale=softmax_scale,
causal=causal,
deterministic=deterministic)[0].unflatten(0, (b, lq))
else:
assert FLASH_ATTN_2_AVAILABLE
x = flash_attn.flash_attn_varlen_func(
q=q,
k=k,
v=v,
cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
0, dtype=torch.int32).to(q.device, non_blocking=True),
cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
0, dtype=torch.int32).to(q.device, non_blocking=True),
max_seqlen_q=lq,
max_seqlen_k=lk,
dropout_p=dropout_p,
softmax_scale=softmax_scale,
causal=causal,
window_size=window_size,
deterministic=deterministic).unflatten(0, (b, lq))
# output
return x.type(out_dtype)
def attention(
q,
k,
v,
q_lens=None,
k_lens=None,
dropout_p=0.,
softmax_scale=None,
q_scale=None,
causal=False,
window_size=(-1, -1),
deterministic=False,
dtype=torch.bfloat16,
fa_version=None,
):
if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
return flash_attention(
q=q,
k=k,
v=v,
q_lens=q_lens,
k_lens=k_lens,
dropout_p=dropout_p,
softmax_scale=softmax_scale,
q_scale=q_scale,
causal=causal,
window_size=window_size,
deterministic=deterministic,
dtype=dtype,
version=fa_version,
)
else:
if q_lens is not None or k_lens is not None:
warnings.warn(
'Padding mask is disabled when using scaled_dot_product_attention. Padded tokens are not masked out, which can noticeably change the outputs.'
)
attn_mask = None
q = q.transpose(1, 2).to(dtype)
k = k.transpose(1, 2).to(dtype)
v = v.transpose(1, 2).to(dtype)
out = torch.nn.functional.scaled_dot_product_attention(
q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p)
out = out.transpose(1, 2).contiguous()
return out
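`flash_attention` packs the valid tokens of each padded sequence into one flat batch and describes the sequence boundaries with cumulative offsets (`cu_seqlens_q` / `cu_seqlens_k`), built as `torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0)`. The following plain-Python sketch shows the same packing and offsets on lists. `pack_varlen` is a hypothetical helper for illustration, not part of the FlashAttention API.

```python
from itertools import accumulate


def pack_varlen(batch, lens):
    """Hypothetical list-based mirror of the q_lens / k_lens handling.

    batch: padded sequences (each a list of tokens), one per batch element.
    lens:  number of valid (unpadded) tokens in each sequence.
    Returns the packed tokens and the cumulative offsets [0, l0, l0 + l1, ...].
    """
    # Keep only the valid prefix of every sequence and concatenate them,
    # like torch.cat([u[:n] for u, n in zip(q, q_lens)]).
    packed = [tok for seq, n in zip(batch, lens) for tok in seq[:n]]
    # Prepend 0 and take a running sum, like cat([zeros(1), lens]).cumsum(0).
    cu_seqlens = list(accumulate([0] + list(lens)))
    return packed, cu_seqlens


batch = [["a0", "a1", "pad", "pad"],
         ["b0", "b1", "b2", "b3"],
         ["c0", "pad", "pad", "pad"]]
packed, cu = pack_varlen(batch, [2, 4, 1])
print(cu)  # [0, 2, 6, 7]
# Sequence i occupies packed[cu[i]:cu[i + 1]].
print(packed[cu[1]:cu[2]])  # ['b0', 'b1', 'b2', 'b3']
```

The `max_seqlen_q=lq` and `max_seqlen_k=lk` arguments passed alongside are the padded lengths, which upper-bound every entry of `q_lens` and `k_lens`.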
================================================
FILE: wan/modules/clip.py
================================================
# Modified from ``https://github.com/openai/CLIP'' and ``https://github.com/mlfoundations/open_clip''
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import logging
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from .attention import flash_attention
from .tokenizers import HuggingfaceTokenizer
from .xlm_roberta import XLMRoberta
__all__ = [
'XLMRobertaCLIP',
'clip_xlm_roberta_vit_h_14',
'CLIPModel',
]
def pos_interpolate(pos, seq_len):
if pos.size(1) == seq_len:
return pos
else:
src_grid = int(math.sqrt(pos.size(1)))
tar_grid = int(math.sqrt(seq_len))
n = pos.size(1) - src_grid * src_grid
return torch.cat([
pos[:, :n],
F.interpolate(
pos[:, n:].float().reshape(1, src_grid, src_grid, -1).permute(
0, 3, 1, 2),
size=(tar_grid, tar_grid),
mode='bicubic',
align_corners=False).flatten(2).transpose(1, 2)
],
dim=1)
class QuickGELU(nn.Module):
def forward(self, x):
return x * torch.sigmoid(1.702 * x)
class LayerNorm(nn.LayerNorm):
def forward(self, x):
return super().forward(x.float()).type_as(x)
class SelfAttention(nn.Module):
def __init__(self,
dim,
num_heads,
causal=False,
attn_dropout=0.0,
proj_dropout=0.0):
assert dim % num_heads == 0
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.causal = causal
self.attn_dropout = attn_dropout
self.proj_dropout = proj_dropout
# layers
self.to_qkv = nn.Linear(dim, dim * 3)
self.proj = nn.Linear(dim, dim)
def forward(self, x):
"""
x: [B, L, C].
"""
b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
# compute query, key, value
q, k, v = self.to_qkv(x).view(b, s, 3, n, d).unbind(2)
# compute attention
p = self.attn_dropout if self.training else 0.0
x = flash_attention(q, k, v, dropout_p=p, causal=self.causal, version=2)
x = x.reshape(b, s, c)
# output
x = self.proj(x)
x = F.dropout(x, self.proj_dropout, self.training)
return x
class SwiGLU(nn.Module):
def __init__(self, dim, mid_dim):
super().__init__()
self.dim = dim
self.mid_dim = mid_dim
# layers
self.fc1 = nn.Linear(dim, mid_dim)
self.fc2 = nn.Linear(dim, mid_dim)
self.fc3 = nn.Linear(mid_dim, dim)
def forward(self, x):
x = F.silu(self.fc1(x)) * self.fc2(x)
x = self.fc3(x)
return x
class AttentionBlock(nn.Module):
def __init__(self,
dim,
mlp_ratio,
num_heads,
post_norm=False,
causal=False,
activation='quick_gelu',
attn_dropout=0.0,
proj_dropout=0.0,
norm_eps=1e-5):
assert activation in ['quick_gelu', 'gelu', 'swi_glu']
super().__init__()
self.dim = dim
self.mlp_ratio = mlp_ratio
self.num_heads = num_heads
self.post_norm = post_norm
self.causal = causal
self.norm_eps = norm_eps
# layers
self.norm1 = LayerNorm(dim, eps=norm_eps)
self.attn = SelfAttention(dim, num_heads, causal, attn_dropout,
proj_dropout)
self.norm2 = LayerNorm(dim, eps=norm_eps)
if activation == 'swi_glu':
self.mlp = SwiGLU(dim, int(dim * mlp_ratio))
else:
self.mlp = nn.Sequential(
nn.Linear(dim, int(dim * mlp_ratio)),
QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
def forward(self, x):
if self.post_norm:
x = x + self.norm1(self.attn(x))
x = x + self.norm2(self.mlp(x))
else:
x = x + self.attn(self.norm1(x))
x = x + self.mlp(self.norm2(x))
return x
class AttentionPool(nn.Module):
def __init__(self,
dim,
mlp_ratio,
num_heads,
activation='gelu',
proj_dropout=0.0,
norm_eps=1e-5):
assert dim % num_heads == 0
super().__init__()
self.dim = dim
self.mlp_ratio = mlp_ratio
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.proj_dropout = proj_dropout
self.norm_eps = norm_eps
# layers
gain = 1.0 / math.sqrt(dim)
self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
self.to_q = nn.Linear(dim, dim)
self.to_kv = nn.Linear(dim, dim * 2)
self.proj = nn.Linear(dim, dim)
self.norm = LayerNorm(dim, eps=norm_eps)
self.mlp = nn.Sequential(
nn.Linear(dim, int(dim * mlp_ratio)),
QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
def forward(self, x):
"""
x: [B, L, C].
"""
b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
# compute query, key, value
q = self.to_q(self.cls_embedding).view(1, 1, n, d).expand(b, -1, -1, -1)
k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2)
# compute attention
x = flash_attention(q, k, v, version=2)
x = x.reshape(b, 1, c)
# output
x = self.proj(x)
x = F.dropout(x, self.proj_dropout, self.training)
# mlp
x = x + self.mlp(self.norm(x))
return x[:, 0]
class VisionTransformer(nn.Module):
def __init__(self,
image_size=224,
patch_size=16,
dim=768,
mlp_ratio=4,
out_dim=512,
num_heads=12,
num_layers=12,
pool_type='token',
pre_norm=True,
post_norm=False,
activation='quick_gelu',
attn_dropout=0.0,
proj_dropout=0.0,
embedding_dropout=0.0,
norm_eps=1e-5):
if image_size % patch_size != 0:
print(
'[WARNING] image_size is not divisible by patch_size',
flush=True)
assert pool_type in ('token', 'token_fc', 'attn_pool')
out_dim = out_dim or dim
super().__init__()
self.image_size = image_size
self.patch_size = patch_size
self.num_patches = (image_size // patch_size)**2
self.dim = dim
self.mlp_ratio = mlp_ratio
self.out_dim = out_dim
self.num_heads = num_heads
self.num_layers = num_layers
self.pool_type = pool_type
self.post_norm = post_norm
self.norm_eps = norm_eps
# embeddings
gain = 1.0 / math.sqrt(dim)
self.patch_embedding = nn.Conv2d(
3,
dim,
kernel_size=patch_size,
stride=patch_size,
bias=not pre_norm)
if pool_type in ('token', 'token_fc'):
self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
self.pos_embedding = nn.Parameter(gain * torch.randn(
1, self.num_patches +
(1 if pool_type in ('token', 'token_fc') else 0), dim))
self.dropout = nn.Dropout(embedding_dropout)
# transformer
self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None
self.transformer = nn.Sequential(*[
AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False,
activation, attn_dropout, proj_dropout, norm_eps)
for _ in range(num_layers)
])
self.post_norm = LayerNorm(dim, eps=norm_eps)
# head
if pool_type == 'token':
self.head = nn.Parameter(gain * torch.randn(dim, out_dim))
elif pool_type == 'token_fc':
self.head = nn.Linear(dim, out_dim)
elif pool_type == 'attn_pool':
self.head = AttentionPool(dim, mlp_ratio, num_heads, activation,
proj_dropout, norm_eps)
def forward(self, x, interpolation=False, use_31_block=False):
b = x.size(0)
# embeddings
x = self.patch_embedding(x).flatten(2).permute(0, 2, 1)
if self.pool_type in ('token', 'token_fc'):
x = torch.cat([self.cls_embedding.expand(b, -1, -1), x], dim=1)
if interpolation:
e = pos_interpolate(self.pos_embedding, x.size(1))
else:
e = self.pos_embedding
x = self.dropout(x + e)
if self.pre_norm is not None:
x = self.pre_norm(x)
# transformer
if use_31_block:
            x = self.transformer[:-1](x)
            return x
        else:
            x = self.transformer(x)
            return x


class XLMRobertaWithHead(XLMRoberta):

    def __init__(self, **kwargs):
        self.out_dim = kwargs.pop('out_dim')
        super().__init__(**kwargs)

        # head
        mid_dim = (self.dim + self.out_dim) // 2
        self.head = nn.Sequential(
            nn.Linear(self.dim, mid_dim, bias=False), nn.GELU(),
            nn.Linear(mid_dim, self.out_dim, bias=False))

    def forward(self, ids):
        # xlm-roberta
        x = super().forward(ids)

        # average pooling
        mask = ids.ne(self.pad_id).unsqueeze(-1).to(x)
        x = (x * mask).sum(dim=1) / mask.sum(dim=1)

        # head
        x = self.head(x)
        return x


class XLMRobertaCLIP(nn.Module):

    def __init__(self,
                 embed_dim=1024,
                 image_size=224,
                 patch_size=14,
                 vision_dim=1280,
                 vision_mlp_ratio=4,
                 vision_heads=16,
                 vision_layers=32,
                 vision_pool='token',
                 vision_pre_norm=True,
                 vision_post_norm=False,
                 activation='gelu',
                 vocab_size=250002,
                 max_text_len=514,
                 type_size=1,
                 pad_id=1,
                 text_dim=1024,
                 text_heads=16,
                 text_layers=24,
                 text_post_norm=True,
                 text_dropout=0.1,
                 attn_dropout=0.0,
                 proj_dropout=0.0,
                 embedding_dropout=0.0,
                 norm_eps=1e-5):
        super().__init__()
        self.embed_dim = embed_dim
        self.image_size = image_size
        self.patch_size = patch_size
        self.vision_dim = vision_dim
        self.vision_mlp_ratio = vision_mlp_ratio
        self.vision_heads = vision_heads
        self.vision_layers = vision_layers
        self.vision_pre_norm = vision_pre_norm
        self.vision_post_norm = vision_post_norm
        self.activation = activation
        self.vocab_size = vocab_size
        self.max_text_len = max_text_len
        self.type_size = type_size
        self.pad_id = pad_id
        self.text_dim = text_dim
        self.text_heads = text_heads
        self.text_layers = text_layers
        self.text_post_norm = text_post_norm
        self.norm_eps = norm_eps

        # models
        self.visual = VisionTransformer(
            image_size=image_size,
            patch_size=patch_size,
            dim=vision_dim,
            mlp_ratio=vision_mlp_ratio,
            out_dim=embed_dim,
            num_heads=vision_heads,
            num_layers=vision_layers,
            pool_type=vision_pool,
            pre_norm=vision_pre_norm,
            post_norm=vision_post_norm,
            activation=activation,
            attn_dropout=attn_dropout,
            proj_dropout=proj_dropout,
            embedding_dropout=embedding_dropout,
            norm_eps=norm_eps)
        self.textual = XLMRobertaWithHead(
            vocab_size=vocab_size,
            max_seq_len=max_text_len,
            type_size=type_size,
            pad_id=pad_id,
            dim=text_dim,
            out_dim=embed_dim,
            num_heads=text_heads,
            num_layers=text_layers,
            post_norm=text_post_norm,
            dropout=text_dropout)
        self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))

    def forward(self, imgs, txt_ids):
        """
        imgs:       [B, 3, H, W] of torch.float32.
        - mean:     [0.48145466, 0.4578275, 0.40821073]
        - std:      [0.26862954, 0.26130258, 0.27577711]
        txt_ids:    [B, L] of torch.long.
                    Encoded by data.CLIPTokenizer.
        """
        xi = self.visual(imgs)
        xt = self.textual(txt_ids)
        return xi, xt

    def param_groups(self):
        groups = [{
            'params': [
                p for n, p in self.named_parameters()
                if 'norm' in n or n.endswith('bias')
            ],
            'weight_decay': 0.0
        }, {
            'params': [
                p for n, p in self.named_parameters()
                if not ('norm' in n or n.endswith('bias'))
            ]
        }]
        return groups


def _clip(pretrained=False,
          pretrained_name=None,
          model_cls=XLMRobertaCLIP,
          return_transforms=False,
          return_tokenizer=False,
          tokenizer_padding='eos',
          dtype=torch.float32,
          device='cpu',
          **kwargs):
    # init a model on device
    with torch.device(device):
        model = model_cls(**kwargs)

    # set device
    model = model.to(dtype=dtype, device=device)
    output = (model,)

    # init transforms
    if return_transforms:
        # mean and std
        if 'siglip' in pretrained_name.lower():
            mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
        else:
            mean = [0.48145466, 0.4578275, 0.40821073]
            std = [0.26862954, 0.26130258, 0.27577711]

        # transforms
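In `XLMRobertaWithHead.forward` (from `wan/modules/clip.py` above), the text embedding is a masked mean. `ids.ne(self.pad_id)` marks the real tokens, padded positions are zeroed, and the sum is divided by the number of real tokens rather than by the sequence length `L`. Below is a minimal pure-Python sketch of that pooling step alone. Nested lists stand in for the torch tensors. `masked_mean_pool` is a hypothetical name and is not part of the repository.

```python
from typing import List, Sequence

Vector = List[float]


def masked_mean_pool(hidden: Sequence[Sequence[Vector]],
                     ids: Sequence[Sequence[int]],
                     pad_id: int = 1) -> List[Vector]:
    """Average token vectors per sequence, skipping padding positions.

    hidden: B sequences of L token vectors (stand-in for a [B, L, C] tensor).
    ids:    B sequences of L token ids (stand-in for a [B, L] tensor).
    """
    pooled = []
    for seq_vecs, seq_ids in zip(hidden, ids):
        total = [0.0] * len(seq_vecs[0])
        count = 0
        for vec, tok in zip(seq_vecs, seq_ids):
            if tok != pad_id:  # same test as ids.ne(self.pad_id)
                total = [t + v for t, v in zip(total, vec)]
                count += 1
        # Divide by the number of real tokens, not by L.
        # A sequence made entirely of padding raises ZeroDivisionError here.
        pooled.append([t / count for t in total])
    return pooled


# One sequence of length 3 whose last position is padding (pad_id=1).
print(masked_mean_pool([[[1.0, 2.0], [3.0, 4.0], [9.0, 9.0]]], [[0, 5, 1]]))
# -> [[2.0, 3.0]]
```

The padded vector `[9.0, 9.0]` has no effect on the result. Dividing by `L = 3` instead of by the two real tokens would give a different, length-dependent embedding.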
SYMBOL INDEX (331 symbols across 27 files)
FILE: generate.py
function _validate_args (line 64) | def _validate_args(args):
function _parse_args (line 99) | def _parse_args():
function _init_logging (line 254) | def _init_logging(rank):
function generate (line 266) | def generate(args):
FILE: gradio/fl2v_14B_singleGPU.py
function load_model (line 27) | def load_model(value):
function prompt_enc (line 59) | def prompt_enc(prompt, img_first, img_last, tar_lang):
function flf2v_generation (line 73) | def flf2v_generation(flf2vid_prompt, flf2vid_image_first, flf2vid_image_...
function gradio_interface (line 113) | def gradio_interface():
function _parse_args (line 212) | def _parse_args():
FILE: gradio/i2v_14B_singleGPU.py
function load_model (line 28) | def load_model(value):
function prompt_enc (line 87) | def prompt_enc(prompt, img, tar_lang):
function i2v_generation (line 101) | def i2v_generation(img2vid_prompt, img2vid_image, resolution, sd_steps,
function gradio_interface (line 149) | def gradio_interface():
function _parse_args (line 241) | def _parse_args():
FILE: gradio/t2i_14B_singleGPU.py
function prompt_enc (line 26) | def prompt_enc(prompt, tar_lang):
function t2i_generation (line 35) | def t2i_generation(txt2img_prompt, resolution, sd_steps, guide_scale,
function gradio_interface (line 64) | def gradio_interface():
function _parse_args (line 152) | def _parse_args():
FILE: gradio/t2v_1.3B_singleGPU.py
function prompt_enc (line 26) | def prompt_enc(prompt, tar_lang):
function t2v_generation (line 35) | def t2v_generation(txt2vid_prompt, resolution, sd_steps, guide_scale,
function gradio_interface (line 64) | def gradio_interface():
function _parse_args (line 154) | def _parse_args():
FILE: gradio/t2v_14B_singleGPU.py
function prompt_enc (line 26) | def prompt_enc(prompt, tar_lang):
function t2v_generation (line 35) | def t2v_generation(txt2vid_prompt, resolution, sd_steps, guide_scale,
function gradio_interface (line 64) | def gradio_interface():
function _parse_args (line 152) | def _parse_args():
FILE: gradio/vace.py
class FixedSizeQueue (line 22) | class FixedSizeQueue:
method __init__ (line 24) | def __init__(self, max_size):
method add (line 28) | def add(self, item):
method get (line 33) | def get(self):
method __repr__ (line 36) | def __repr__(self):
class VACEInference (line 40) | class VACEInference:
method __init__ (line 42) | def __init__(self,
method create_ui (line 70) | def create_ui(self, *args, **kwargs):
method generate (line 211) | def generate(self, output_gallery, src_video, src_mask, src_ref_image_1,
method set_callbacks (line 267) | def set_callbacks(self, **kwargs):
FILE: wan/distributed/fsdp.py
function shard_model (line 12) | def shard_model(
function free_model (line 37) | def free_model(model):
FILE: wan/distributed/xdit_context_parallel.py
function pad_freqs (line 14) | def pad_freqs(original_tensor, target_len):
function rope_apply (line 28) | def rope_apply(x, grid_sizes, freqs):
function usp_dit_forward_vace (line 68) | def usp_dit_forward_vace(self, x, vace_context, seq_len, kwargs):
function usp_dit_forward (line 93) | def usp_dit_forward(
function usp_attn_forward (line 183) | def usp_attn_forward(self,
FILE: wan/first_last_frame2video.py
class WanFLF2V (line 32) | class WanFLF2V:
method __init__ (line 34) | def __init__(
method generate (line 133) | def generate(self,
FILE: wan/image2video.py
class WanI2V (line 32) | class WanI2V:
method __init__ (line 34) | def __init__(
method generate (line 133) | def generate(self,
FILE: wan/modules/attention.py
function flash_attention (line 24) | def flash_attention(
function attention (line 133) | def attention(
FILE: wan/modules/clip.py
function pos_interpolate (line 22) | def pos_interpolate(pos, seq_len):
class QuickGELU (line 41) | class QuickGELU(nn.Module):
method forward (line 43) | def forward(self, x):
class LayerNorm (line 47) | class LayerNorm(nn.LayerNorm):
method forward (line 49) | def forward(self, x):
class SelfAttention (line 53) | class SelfAttention(nn.Module):
method __init__ (line 55) | def __init__(self,
method forward (line 74) | def forward(self, x):
class SwiGLU (line 94) | class SwiGLU(nn.Module):
method __init__ (line 96) | def __init__(self, dim, mid_dim):
method forward (line 106) | def forward(self, x):
class AttentionBlock (line 112) | class AttentionBlock(nn.Module):
method __init__ (line 114) | def __init__(self,
method forward (line 146) | def forward(self, x):
class AttentionPool (line 156) | class AttentionPool(nn.Module):
method __init__ (line 158) | def __init__(self,
method forward (line 186) | def forward(self, x):
class VisionTransformer (line 209) | class VisionTransformer(nn.Module):
method __init__ (line 211) | def __init__(self,
method forward (line 279) | def forward(self, x, interpolation=False, use_31_block=False):
class XLMRobertaWithHead (line 303) | class XLMRobertaWithHead(XLMRoberta):
method __init__ (line 305) | def __init__(self, **kwargs):
method forward (line 315) | def forward(self, ids):
class XLMRobertaCLIP (line 328) | class XLMRobertaCLIP(nn.Module):
method __init__ (line 330) | def __init__(self,
method forward (line 406) | def forward(self, imgs, txt_ids):
method param_groups (line 418) | def param_groups(self):
function _clip (line 434) | def _clip(pretrained=False,
function clip_xlm_roberta_vit_h_14 (line 471) | def clip_xlm_roberta_vit_h_14(
class CLIPModel (line 501) | class CLIPModel:
method __init__ (line 503) | def __init__(self, dtype, device, checkpoint_path, tokenizer_path):
method visual (line 527) | def visual(self, videos):
FILE: wan/modules/model.py
function sinusoidal_embedding_1d (line 18) | def sinusoidal_embedding_1d(dim, position):
function rope_params (line 32) | def rope_params(max_seq_len, dim, theta=10000):
function rope_apply (line 43) | def rope_apply(x, grid_sizes, freqs):
class WanRMSNorm (line 73) | class WanRMSNorm(nn.Module):
method __init__ (line 75) | def __init__(self, dim, eps=1e-5):
method forward (line 81) | def forward(self, x):
method _norm (line 88) | def _norm(self, x):
class WanLayerNorm (line 92) | class WanLayerNorm(nn.LayerNorm):
method __init__ (line 94) | def __init__(self, dim, eps=1e-6, elementwise_affine=False):
method forward (line 97) | def forward(self, x):
class WanSelfAttention (line 105) | class WanSelfAttention(nn.Module):
method __init__ (line 107) | def __init__(self,
method forward (line 130) | def forward(self, x, seq_lens, grid_sizes, freqs):
class WanT2VCrossAttention (line 162) | class WanT2VCrossAttention(WanSelfAttention):
method forward (line 164) | def forward(self, x, context, context_lens):
class WanI2VCrossAttention (line 187) | class WanI2VCrossAttention(WanSelfAttention):
method __init__ (line 189) | def __init__(self,
method forward (line 202) | def forward(self, x, context, context_lens):
class WanAttentionBlock (line 238) | class WanAttentionBlock(nn.Module):
method __init__ (line 240) | def __init__(self,
method forward (line 278) | def forward(
class Head (line 320) | class Head(nn.Module):
method __init__ (line 322) | def __init__(self, dim, out_dim, patch_size, eps=1e-6):
method forward (line 337) | def forward(self, x, e):
class MLPProj (line 350) | class MLPProj(torch.nn.Module):
method __init__ (line 352) | def __init__(self, in_dim, out_dim, flf_pos_emb=False):
method forward (line 363) | def forward(self, image_embeds):
class WanModel (line 372) | class WanModel(ModelMixin, ConfigMixin):
method __init__ (line 383) | def __init__(self,
method forward (line 493) | def forward(
method unpatchify (line 584) | def unpatchify(self, x, grid_sizes):
method init_weights (line 609) | def init_weights(self):
FILE: wan/modules/t5.py
function fp16_clamp (line 20) | def fp16_clamp(x):
function init_weights (line 27) | def init_weights(m):
class GELU (line 46) | class GELU(nn.Module):
method forward (line 48) | def forward(self, x):
class T5LayerNorm (line 53) | class T5LayerNorm(nn.Module):
method __init__ (line 55) | def __init__(self, dim, eps=1e-6):
method forward (line 61) | def forward(self, x):
class T5Attention (line 69) | class T5Attention(nn.Module):
method __init__ (line 71) | def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
method forward (line 86) | def forward(self, x, context=None, mask=None, pos_bias=None):
class T5FeedForward (line 123) | class T5FeedForward(nn.Module):
method __init__ (line 125) | def __init__(self, dim, dim_ffn, dropout=0.1):
method forward (line 136) | def forward(self, x):
class T5SelfAttention (line 144) | class T5SelfAttention(nn.Module):
method __init__ (line 146) | def __init__(self,
method forward (line 170) | def forward(self, x, mask=None, pos_bias=None):
class T5CrossAttention (line 178) | class T5CrossAttention(nn.Module):
method __init__ (line 180) | def __init__(self,
method forward (line 206) | def forward(self,
class T5RelativeEmbedding (line 221) | class T5RelativeEmbedding(nn.Module):
method __init__ (line 223) | def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
method forward (line 233) | def forward(self, lq, lk):
method _relative_position_bucket (line 245) | def _relative_position_bucket(self, rel_pos):
class T5Encoder (line 267) | class T5Encoder(nn.Module):
method __init__ (line 269) | def __init__(self,
method forward (line 303) | def forward(self, ids, mask=None):
class T5Decoder (line 315) | class T5Decoder(nn.Module):
method __init__ (line 317) | def __init__(self,
method forward (line 351) | def forward(self, ids, mask=None, encoder_states=None, encoder_mask=No...
class T5Model (line 372) | class T5Model(nn.Module):
method __init__ (line 374) | def __init__(self,
method forward (line 408) | def forward(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask):
function _t5 (line 415) | def _t5(name,
function umt5_xxl (line 456) | def umt5_xxl(**kwargs):
class T5EncoderModel (line 472) | class T5EncoderModel:
method __init__ (line 474) | def __init__(
method __call__ (line 506) | def __call__(self, texts, device):
FILE: wan/modules/tokenizers.py
function basic_clean (line 12) | def basic_clean(text):
function whitespace_clean (line 18) | def whitespace_clean(text):
function canonicalize (line 24) | def canonicalize(text, keep_punctuation_exact_string=None):
class HuggingfaceTokenizer (line 37) | class HuggingfaceTokenizer:
method __init__ (line 39) | def __init__(self, name, seq_len=None, clean=None, **kwargs):
method __call__ (line 49) | def __call__(self, sequence, **kwargs):
method _clean (line 75) | def _clean(self, text):
FILE: wan/modules/vace_model.py
class VaceWanAttentionBlock (line 10) | class VaceWanAttentionBlock(WanAttentionBlock):
method __init__ (line 12) | def __init__(self,
method forward (line 33) | def forward(self, c, x, **kwargs):
class BaseWanAttentionBlock (line 42) | class BaseWanAttentionBlock(WanAttentionBlock):
method __init__ (line 44) | def __init__(self,
method forward (line 58) | def forward(self, x, hints, context_scale=1.0, **kwargs):
class VaceWanModel (line 65) | class VaceWanModel(WanModel):
method __init__ (line 68) | def __init__(self,
method forward_vace (line 136) | def forward_vace(self, x, vace_context, seq_len, kwargs):
method forward (line 155) | def forward(
FILE: wan/modules/vae.py
class CausalConv3d (line 17) | class CausalConv3d(nn.Conv3d):
method __init__ (line 22) | def __init__(self, *args, **kwargs):
method forward (line 28) | def forward(self, x, cache_x=None):
class RMS_norm (line 39) | class RMS_norm(nn.Module):
method __init__ (line 41) | def __init__(self, dim, channel_first=True, images=True, bias=False):
method forward (line 51) | def forward(self, x):
class Upsample (line 57) | class Upsample(nn.Upsample):
method forward (line 59) | def forward(self, x):
class Resample (line 66) | class Resample(nn.Module):
method __init__ (line 68) | def __init__(self, dim, mode):
method forward (line 101) | def forward(self, x, feat_cache=None, feat_idx=[0]):
method init_weight (line 162) | def init_weight(self, conv):
method init_weight2 (line 174) | def init_weight2(self, conv):
class ResidualBlock (line 186) | class ResidualBlock(nn.Module):
method __init__ (line 188) | def __init__(self, in_dim, out_dim, dropout=0.0):
method forward (line 202) | def forward(self, x, feat_cache=None, feat_idx=[0]):
class AttentionBlock (line 223) | class AttentionBlock(nn.Module):
method __init__ (line 228) | def __init__(self, dim):
method forward (line 240) | def forward(self, x):
class Encoder3d (line 265) | class Encoder3d(nn.Module):
method __init__ (line 267) | def __init__(self,
method forward (line 318) | def forward(self, x, feat_cache=None, feat_idx=[0]):
class Decoder3d (line 369) | class Decoder3d(nn.Module):
method __init__ (line 371) | def __init__(self,
method forward (line 423) | def forward(self, x, feat_cache=None, feat_idx=[0]):
function count_conv3d (line 475) | def count_conv3d(model):
class WanVAE_ (line 483) | class WanVAE_(nn.Module):
method __init__ (line 485) | def __init__(self,
method forward (line 510) | def forward(self, x):
method encode (line 516) | def encode(self, x, scale):
method decode (line 544) | def decode(self, z, scale):
method reparameterize (line 570) | def reparameterize(self, mu, log_var):
method sample (line 575) | def sample(self, imgs, deterministic=False):
method clear_cache (line 582) | def clear_cache(self):
function _video_vae (line 592) | def _video_vae(pretrained_path=None, z_dim=None, device='cpu', **kwargs):
class WanVAE (line 619) | class WanVAE:
method __init__ (line 621) | def __init__(self,
method encode (line 647) | def encode(self, videos):
method decode (line 657) | def decode(self, zs):
FILE: wan/modules/xlm_roberta.py
class SelfAttention (line 10) | class SelfAttention(nn.Module):
method __init__ (line 12) | def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5):
method forward (line 27) | def forward(self, x, mask):
class AttentionBlock (line 49) | class AttentionBlock(nn.Module):
method __init__ (line 51) | def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5):
method forward (line 66) | def forward(self, x, mask):
class XLMRoberta (line 76) | class XLMRoberta(nn.Module):
method __init__ (line 81) | def __init__(self,
method forward (line 118) | def forward(self, ids):
function xlm_roberta_large (line 146) | def xlm_roberta_large(pretrained=False,
FILE: wan/text2video.py
class WanT2V (line 29) | class WanT2V:
method __init__ (line 31) | def __init__(
method generate (line 114) | def generate(self,
FILE: wan/utils/fm_solvers.py
function get_sampling_sigmas (line 24) | def get_sampling_sigmas(sampling_steps, shift):
function retrieve_timesteps (line 31) | def retrieve_timesteps(
class FlowDPMSolverMultistepScheduler (line 71) | class FlowDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
method __init__ (line 131) | def __init__(
method step_index (line 204) | def step_index(self):
method begin_index (line 211) | def begin_index(self):
method set_begin_index (line 218) | def set_begin_index(self, begin_index: int = 0):
method set_timesteps (line 228) | def set_timesteps(
method _threshold_sample (line 294) | def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
method _sigma_to_t (line 332) | def _sigma_to_t(self, sigma):
method _sigma_to_alpha_sigma_t (line 335) | def _sigma_to_alpha_sigma_t(self, sigma):
method time_shift (line 339) | def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
method convert_model_output (line 343) | def convert_model_output(
method dpm_solver_first_order_update (line 417) | def dpm_solver_first_order_update(
method multistep_dpm_solver_second_order_update (line 488) | def multistep_dpm_solver_second_order_update(
method multistep_dpm_solver_third_order_update (line 598) | def multistep_dpm_solver_third_order_update(
method index_for_timestep (line 681) | def index_for_timestep(self, timestep, schedule_timesteps=None):
method _init_step_index (line 695) | def _init_step_index(self, timestep):
method step (line 708) | def step(
method scale_model_input (line 802) | def scale_model_input(self, sample: torch.Tensor, *args,
method add_noise (line 817) | def add_noise(
method __len__ (line 858) | def __len__(self):
FILE: wan/utils/fm_solvers_unipc.py
class FlowUniPCMultistepScheduler (line 22) | class FlowUniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
method __init__ (line 79) | def __init__(
method step_index (line 137) | def step_index(self):
method begin_index (line 144) | def begin_index(self):
method set_begin_index (line 151) | def set_begin_index(self, begin_index: int = 0):
method set_timesteps (line 162) | def set_timesteps(
method _threshold_sample (line 232) | def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
method _sigma_to_t (line 271) | def _sigma_to_t(self, sigma):
method _sigma_to_alpha_sigma_t (line 274) | def _sigma_to_alpha_sigma_t(self, sigma):
method time_shift (line 278) | def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
method convert_model_output (line 281) | def convert_model_output(
method multistep_uni_p_bh_update (line 352) | def multistep_uni_p_bh_update(
method multistep_uni_c_bh_update (line 488) | def multistep_uni_c_bh_update(
method index_for_timestep (line 630) | def index_for_timestep(self, timestep, schedule_timesteps=None):
method _init_step_index (line 645) | def _init_step_index(self, timestep):
method step (line 657) | def step(self,
method scale_model_input (line 743) | def scale_model_input(self, sample: torch.Tensor, *args,
method add_noise (line 760) | def add_noise(
method __len__ (line 801) | def __len__(self):
FILE: wan/utils/prompt_extend.py
class PromptOutput (line 153) | class PromptOutput(object):
method add_custom_field (line 160) | def add_custom_field(self, key: str, value) -> None:
class PromptExpander (line 164) | class PromptExpander:
method __init__ (line 166) | def __init__(self, model_name, is_vl=False, device=0, **kwargs):
method extend_with_img (line 171) | def extend_with_img(self,
method extend (line 180) | def extend(self, prompt, system_prompt, seed=-1, *args, **kwargs):
method decide_system_prompt (line 183) | def decide_system_prompt(self, tar_lang="zh", multi_images_input=False):
method __call__ (line 189) | def __call__(self,
class DashScopePromptExpander (line 213) | class DashScopePromptExpander(PromptExpander):
method __init__ (line 215) | def __init__(self,
method extend (line 252) | def extend(self, prompt, system_prompt, seed=-1, *args, **kwargs):
method extend_with_img (line 288) | def extend_with_img(self,
class QwenPromptExpander (line 364) | class QwenPromptExpander(PromptExpander):
method __init__ (line 373) | def __init__(self, model_name=None, device=0, is_vl=False, **kwargs):
method extend (line 433) | def extend(self, prompt, system_prompt, seed=-1, *args, **kwargs):
method extend_with_img (line 464) | def extend_with_img(self,
FILE: wan/utils/qwen_vl_utils.py
function round_by_factor (line 39) | def round_by_factor(number: int, factor: int) -> int:
function ceil_by_factor (line 44) | def ceil_by_factor(number: int, factor: int) -> int:
function floor_by_factor (line 49) | def floor_by_factor(number: int, factor: int) -> int:
function smart_resize (line 54) | def smart_resize(height: int,
function fetch_image (line 85) | def fetch_image(ele: dict[str, str | Image.Image],
function smart_nframes (line 133) | def smart_nframes(
function _read_video_torchvision (line 177) | def _read_video_torchvision(ele: dict,) -> torch.Tensor:
function is_decord_available (line 215) | def is_decord_available() -> bool:
function _read_video_decord (line 221) | def _read_video_decord(ele: dict,) -> torch.Tensor:
function get_video_reader_backend (line 261) | def get_video_reader_backend() -> str:
function fetch_video (line 274) | def fetch_video(
function extract_vision_info (line 328) | def extract_vision_info(
function process_vision_info (line 344) | def process_vision_info(
FILE: wan/utils/utils.py
function rand_name (line 14) | def rand_name(length=8, suffix=''):
function cache_video (line 23) | def cache_video(tensor,
function cache_image (line 64) | def cache_image(tensor,
function str2bool (line 94) | def str2bool(v):
FILE: wan/utils/vace_processor.py
class VaceImageProcessor (line 9) | class VaceImageProcessor(object):
method __init__ (line 11) | def __init__(self, downsample=None, seq_len=None):
method _pillow_convert (line 15) | def _pillow_convert(self, image, cvt_type='RGB'):
method _load_image (line 30) | def _load_image(self, img_path):
method _resize_crop (line 37) | def _resize_crop(self, img, oh, ow, normalize=True):
method _image_preprocess (line 60) | def _image_preprocess(self, img, oh, ow, normalize=True, **kwargs):
method load_image (line 63) | def load_image(self, data_key, **kwargs):
method load_image_pair (line 66) | def load_image_pair(self, data_key, data_key2, **kwargs):
method load_image_batch (line 69) | def load_image_batch(self,
class VaceVideoProcessor (line 91) | class VaceVideoProcessor(object):
method __init__ (line 93) | def __init__(self, downsample, min_area, max_area, min_fps, max_fps,
method set_area (line 105) | def set_area(self, area):
method set_seq_len (line 109) | def set_seq_len(self, seq_len):
method resize_crop (line 113) | def resize_crop(video: torch.Tensor, oh: int, ow: int):
method _video_preprocess (line 151) | def _video_preprocess(self, video, oh, ow):
method _get_frameid_bbox_default (line 154) | def _get_frameid_bbox_default(self, fps, frame_timestamps, h, w, crop_...
method _get_frameid_bbox_adjust_last (line 187) | def _get_frameid_bbox_adjust_last(self, fps, frame_timestamps, h, w,
method _get_frameid_bbox (line 219) | def _get_frameid_bbox(self, fps, frame_timestamps, h, w, crop_box, rng):
method load_video (line 227) | def load_video(self, data_key, crop_box=None, seed=2024, **kwargs):
method load_video_pair (line 231) | def load_video_pair(self,
method load_video_batch (line 240) | def load_video_batch(self,
function prepare_source (line 274) | def prepare_source(src_video, src_mask, src_ref_images, num_frames, imag...
FILE: wan/vace.py
class WanVace (line 37) | class WanVace(WanT2V):
method __init__ (line 39) | def __init__(
method vace_encode_frames (line 139) | def vace_encode_frames(self, frames, ref_images, masks=None, vae=None):
method vace_encode_masks (line 174) | def vace_encode_masks(self, masks, ref_images=None, vae_stride=None):
method vace_latent (line 209) | def vace_latent(self, z, m):
method prepare_source (line 212) | def prepare_source(self, src_video, src_mask, src_ref_images, num_frames,
method decode_latent (line 280) | def decode_latent(self, zs, ref_images=None, vae=None):
method generate (line 295) | def generate(self,
class WanVaceMP (line 478) | class WanVaceMP(WanVace):
method __init__ (line 480) | def __init__(self,
method dynamic_load (line 512) | def dynamic_load(self):
method transfer_data_to_cuda (line 544) | def transfer_data_to_cuda(self, data, device):
method mp_worker (line 562) | def mp_worker(self, gpu, gpu_infer, pmi_rank, pmi_world_size, in_q_list,
method generate (line 773) | def generate(self,
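`XLMRobertaCLIP.param_groups` (in `wan/modules/clip.py`, indexed above) assigns parameters to two optimizer groups using only their names. Any parameter whose name contains `norm` or ends with `bias` gets `weight_decay` 0.0. Every other parameter is left to the optimizer's default. The sketch below applies the same name rule to plain `(name, value)` pairs so that it runs without torch. The helper names and the parameter names in the example are hypothetical.

```python
from typing import Any, Dict, Iterable, List, Tuple


def is_no_decay(name: str) -> bool:
    # Same predicate as XLMRobertaCLIP.param_groups.
    return 'norm' in name or name.endswith('bias')


def split_param_groups(
        named_params: Iterable[Tuple[str, Any]]) -> List[Dict[str, Any]]:
    """Split (name, param) pairs into a no-decay group and a default group."""
    named_params = list(named_params)  # allow a one-shot iterator
    no_decay = [p for n, p in named_params if is_no_decay(n)]
    decay = [p for n, p in named_params if not is_no_decay(n)]
    # The second group has no 'weight_decay' key, so an optimizer falls
    # back to its constructor default for those parameters.
    return [{'params': no_decay, 'weight_decay': 0.0}, {'params': decay}]


# Hypothetical parameter names; strings stand in for tensors.
example = [
    ('visual.pre_norm.weight', 'w_norm'),
    ('textual.head.0.weight', 'w_head'),
    ('visual.proj.bias', 'b_proj'),
    ('log_scale', 'log_scale'),
]
for group in split_param_groups(example):
    print(group)
# {'params': ['w_norm', 'b_proj'], 'weight_decay': 0.0}
# {'params': ['w_head', 'log_scale']}
```

Under this rule, the CLIP temperature parameter `log_scale` falls into the default (decayed) group, because its name neither contains `norm` nor ends with `bias`. `torch.optim` optimizers fill in any key missing from a parameter group, such as `weight_decay` in the second group, using the defaults passed to the optimizer's constructor.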
Condensed preview: 45 files, each showing path, character count, and a content snippet.
[
{
"path": ".gitignore",
"chars": 340,
"preview": ".*\n*.py[cod]\n# *.jpg\n*.jpeg\n# *.png\n*.gif\n*.bmp\n*.mp4\n*.mov\n*.mkv\n*.log\n*.zip\n*.pt\n*.pth\n*.ckpt\n*.safetensors\n*.json\n# *"
},
{
"path": "INSTALL.md",
"chars": 1150,
"preview": "# Installation Guide\n\n## Install with pip\n\n```bash\npip install .\npip install .[dev] # Also installs the dev tools\n"
},
{
"path": "LICENSE.txt",
"chars": 11357,
"preview": " Apache License\n Version 2.0, January 2004\n "
},
{
"path": "Makefile",
"chars": 94,
"preview": ".PHONY: format\n\nformat:\n\tisort generate.py gradio wan\n\tyapf -i -r *.py generate.py gradio wan\n"
},
{
"path": "README.md",
"chars": 42825,
"preview": "# Wan2.1\n\n<p align=\"center\">\n <img src=\"assets/logo.png\" width=\"400\"/>\n<p>\n\n<p align=\"center\">\n 💜 <a href=\"https:/"
},
{
"path": "generate.py",
"chars": 21903,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport argparse\nimport logging\nimport os\nimport"
},
{
"path": "gradio/fl2v_14B_singleGPU.py",
"chars": 8277,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport argparse\nimport gc\nimport os\nimport os.p"
},
{
"path": "gradio/i2v_14B_singleGPU.py",
"chars": 9282,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport argparse\nimport gc\nimport os\nimport os.p"
},
{
"path": "gradio/t2i_14B_singleGPU.py",
"chars": 6614,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport argparse\nimport os\nimport os.path as osp"
},
{
"path": "gradio/t2v_1.3B_singleGPU.py",
"chars": 6598,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport argparse\nimport os\nimport os.path as osp"
},
{
"path": "gradio/t2v_14B_singleGPU.py",
"chars": 6598,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport argparse\nimport os\nimport os.path as osp"
},
{
"path": "gradio/vace.py",
"chars": 12925,
"preview": "# -*- coding: utf-8 -*-\n# Copyright (c) Alibaba, Inc. and its affiliates.\n\nimport argparse\nimport datetime\nimport os\nimp"
},
{
"path": "pyproject.toml",
"chars": 1339,
"preview": "[build-system]\nrequires = [\"setuptools>=61.0\"]\nbuild-backend = \"setuptools.build_meta\"\n\n[project]\nname = \"wan\"\nversion ="
},
{
"path": "requirements.txt",
"chars": 227,
"preview": "torch>=2.4.0\ntorchvision>=0.19.0\nopencv-python>=4.9.0.80\ndiffusers>=0.31.0\ntransformers>=4.49.0\ntokenizers>=0.20.3\naccel"
},
{
"path": "tests/README.md",
"chars": 216,
"preview": "\nPut all your models (Wan2.1-T2V-1.3B, Wan2.1-T2V-14B, Wan2.1-I2V-14B-480P, Wan2.1-I2V-14B-720P) in a folder and specify"
},
{
"path": "tests/test.sh",
"chars": 5476,
"preview": "#!/bin/bash\n\n\nif [ \"$#\" -eq 2 ]; then\n MODEL_DIR=$(realpath \"$1\")\n GPUS=$2\nelse\n echo \"Usage: $0 <local model dir> <g"
},
{
"path": "wan/__init__.py",
"chars": 189,
"preview": "from . import configs, distributed, modules\nfrom .first_last_frame2video import WanFLF2V\nfrom .image2video import WanI2V"
},
{
"path": "wan/configs/__init__.py",
"chars": 1458,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport copy\nimport os\n\nos.environ['TOKENIZERS_P"
},
{
"path": "wan/configs/shared_config.py",
"chars": 649,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport torch\nfrom easydict import EasyDict\n\n#--"
},
{
"path": "wan/configs/wan_i2v_14B.py",
"chars": 1035,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport torch\nfrom easydict import EasyDict\n\nfro"
},
{
"path": "wan/configs/wan_t2v_14B.py",
"chars": 742,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nfrom easydict import EasyDict\n\nfrom .shared_con"
},
{
"path": "wan/configs/wan_t2v_1_3B.py",
"chars": 759,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nfrom easydict import EasyDict\n\nfrom .shared_con"
},
{
"path": "wan/distributed/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "wan/distributed/fsdp.py",
"chars": 1307,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport gc\nfrom functools import partial\n\nimport"
},
{
"path": "wan/distributed/xdit_context_parallel.py",
"chars": 6839,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport torch\nimport torch.cuda.amp as amp\nfrom "
},
{
"path": "wan/first_last_frame2video.py",
"chars": 14622,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport gc\nimport logging\nimport math\nimport os\n"
},
{
"path": "wan/image2video.py",
"chars": 13184,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport gc\nimport logging\nimport math\nimport os\n"
},
{
"path": "wan/modules/__init__.py",
"chars": 422,
"preview": "from .attention import flash_attention\nfrom .model import WanModel\nfrom .t5 import T5Decoder, T5Encoder, T5EncoderModel,"
},
{
"path": "wan/modules/attention.py",
"chars": 5435,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport torch\n\ntry:\n import flash_attn_interf"
},
{
"path": "wan/modules/clip.py",
"chars": 16848,
"preview": "# Modified from ``https://github.com/openai/CLIP'' and ``https://github.com/mlfoundations/open_clip''\n# Copyright 2024-2"
},
{
"path": "wan/modules/model.py",
"chars": 21289,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport math\n\nimport torch\nimport torch.cuda.amp"
},
{
"path": "wan/modules/t5.py",
"chars": 16910,
"preview": "# Modified from transformers.models.t5.modeling_t5\n# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserv"
},
{
"path": "wan/modules/tokenizers.py",
"chars": 2431,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport html\nimport string\n\nimport ftfy\nimport r"
},
{
"path": "wan/modules/vace_model.py",
"chars": 8281,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport torch\nimport torch.cuda.amp as amp\nimpor"
},
{
"path": "wan/modules/vae.py",
"chars": 23135,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport logging\n\nimport torch\nimport torch.cuda."
},
{
"path": "wan/modules/xlm_roberta.py",
"chars": 4865,
"preview": "# Modified from transformers.models.xlm_roberta.modeling_xlm_roberta\n# Copyright 2024-2025 The Alibaba Wan Team Authors."
},
{
"path": "wan/text2video.py",
"chars": 10235,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport gc\nimport logging\nimport math\nimport os\n"
},
{
"path": "wan/utils/__init__.py",
"chars": 402,
"preview": "from .fm_solvers import (\n FlowDPMSolverMultistepScheduler,\n get_sampling_sigmas,\n retrieve_timesteps,\n)\nfrom ."
},
{
"path": "wan/utils/fm_solvers.py",
"chars": 40142,
"preview": "# Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep"
},
{
"path": "wan/utils/fm_solvers_unipc.py",
"chars": 32557,
"preview": "# Copied from https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/schedulers/scheduling_unipc_multistep."
},
{
"path": "wan/utils/prompt_extend.py",
"chars": 39552,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport json\nimport math\nimport os\nimport random"
},
{
"path": "wan/utils/qwen_vl_utils.py",
"chars": 13054,
"preview": "# Copied from https://github.com/kq-chen/qwen-vl-utils\n# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights re"
},
{
"path": "wan/utils/utils.py",
"chars": 3256,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport argparse\nimport binascii\nimport os\nimpor"
},
{
"path": "wan/utils/vace_processor.py",
"chars": 11914,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport numpy as np\nimport torch\nimport torch.nn"
},
{
"path": "wan/vace.py",
"chars": 32116,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport gc\nimport logging\nimport math\nimport os\n"
}
]
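Every manifest entry above has the same shape: a repository-relative `path`, a `chars` count, and a truncated `preview`. The sketch below tallies character counts per directory, which is handy for seeing where the bulk of the code lives. It assumes the array has been saved as JSON. The sample entries are copied from the listing with previews omitted, and `chars_by_directory` is a hypothetical helper, not part of the repository.

```python
import json
from collections import defaultdict
from pathlib import PurePosixPath

# A few entries copied from the manifest above (previews omitted for brevity).
MANIFEST_JSON = """
[
  {"path": "wan/modules/attention.py", "chars": 5435},
  {"path": "wan/modules/model.py", "chars": 21289},
  {"path": "wan/modules/vae.py", "chars": 23135},
  {"path": "wan/utils/fm_solvers.py", "chars": 40142},
  {"path": "wan/utils/fm_solvers_unipc.py", "chars": 32557},
  {"path": "wan/text2video.py", "chars": 10235}
]
"""


def chars_by_directory(entries):
    """Sum the 'chars' field of each entry, grouped by its parent directory."""
    totals = defaultdict(int)
    for entry in entries:
        # Manifest paths use forward slashes, so parse them as POSIX paths.
        parent = str(PurePosixPath(entry["path"]).parent)
        totals[parent] += entry["chars"]
    return dict(totals)


entries = json.loads(MANIFEST_JSON)
totals = chars_by_directory(entries)
for directory, size in sorted(totals.items()):
    print(f"{directory}: {size} chars")
```

Grouping by `PurePosixPath(...).parent` counts only a file's immediate directory. Summing whole subtrees would mean adding each entry to every ancestor directory as well.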
About this extraction
This page contains the full source code of the Wan-Video/Wan2.1 GitHub repository, extracted and formatted as plain text. The extraction includes 45 files (448.1 KB), approximately 113.2k tokens, and a symbol index with 331 extracted functions, classes, methods, constants, and types.
Extracted by GitExtract.